Redesign keepserving

- Instead of listening for a pseudo-event, we periodically check whether client
replay, server replay or file reading is active.
- Adjust server replay not to
use tick.
- Adjust readfile to expose a command to check whether reading is in progress.
This commit is contained in:
Aldo Cortesi 2018-05-02 11:20:20 +12:00
parent e963408434
commit 22a4b1d5d4
9 changed files with 78 additions and 40 deletions

View File

@ -1,3 +1,4 @@
import asyncio
from mitmproxy import ctx from mitmproxy import ctx
@ -12,6 +13,28 @@ class KeepServing:
""" """
) )
def event_processing_complete(self): def keepgoing(self) -> bool:
if not ctx.master.options.keepserving: checks = [
"readfile.reading",
"replay.client.count",
"replay.server.count",
]
return any([ctx.master.commands.call(c) for c in checks])
def shutdown(self): # pragma: no cover
ctx.master.shutdown() ctx.master.shutdown()
async def watch(self):
while True:
await asyncio.sleep(0.1)
if not self.keepgoing():
self.shutdown()
def running(self):
opts = [
ctx.options.client_replay,
ctx.options.server_replay,
ctx.options.rfile,
]
if any(opts) and not ctx.options.keepserving:
asyncio.get_event_loop().create_task(self.watch())

View File

@ -7,6 +7,7 @@ from mitmproxy import ctx
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import flowfilter from mitmproxy import flowfilter
from mitmproxy import io from mitmproxy import io
from mitmproxy import command
class ReadFile: class ReadFile:
@ -15,6 +16,7 @@ class ReadFile:
""" """
def __init__(self): def __init__(self):
self.filter = None self.filter = None
self.is_reading = False
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
@ -65,17 +67,23 @@ class ReadFile:
raise exceptions.FlowReadException(str(e)) from e raise exceptions.FlowReadException(str(e)) from e
async def doread(self, rfile): async def doread(self, rfile):
self.is_reading = True
try: try:
await self.load_flows_from_path(ctx.options.rfile) await self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e: except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e raise exceptions.OptionsError(e) from e
finally: finally:
self.is_reading = False
ctx.master.addons.trigger("processing_complete") ctx.master.addons.trigger("processing_complete")
def running(self): def running(self):
if ctx.options.rfile: if ctx.options.rfile:
asyncio.get_event_loop().create_task(self.doread(ctx.options.rfile)) asyncio.get_event_loop().create_task(self.doread(ctx.options.rfile))
@command.command("readfile.reading")
def reading(self) -> bool:
return self.is_reading
class ReadFileStdin(ReadFile): class ReadFileStdin(ReadFile):
"""Support the special case of "-" for reading from stdin""" """Support the special case of "-" for reading from stdin"""

View File

@ -13,8 +13,6 @@ import mitmproxy.types
class ServerPlayback: class ServerPlayback:
def __init__(self): def __init__(self):
self.flowmap = {} self.flowmap = {}
self.stop = False
self.final_flow = None
self.configured = False self.configured = False
def load(self, loader): def load(self, loader):
@ -175,10 +173,6 @@ class ServerPlayback:
raise exceptions.OptionsError(str(e)) raise exceptions.OptionsError(str(e))
self.load_flows(flows) self.load_flows(flows)
def tick(self):
if self.stop and not self.final_flow.live:
ctx.master.addons.trigger("processing_complete")
def request(self, f): def request(self, f):
if self.flowmap: if self.flowmap:
rflow = self.next_flow(f) rflow = self.next_flow(f)
@ -188,9 +182,6 @@ class ServerPlayback:
if ctx.options.server_replay_refresh: if ctx.options.server_replay_refresh:
response.refresh() response.refresh()
f.response = response f.response = response
if not self.flowmap:
self.final_flow = f
self.stop = True
elif ctx.options.server_replay_kill_extra: elif ctx.options.server_replay_kill_extra:
ctx.log.warn( ctx.log.warn(
"server_playback: killed non-replay request {}".format( "server_playback: killed non-replay request {}".format(

View File

@ -127,10 +127,7 @@ class Master:
""" """
if not self.should_exit.is_set(): if not self.should_exit.is_set():
self.should_exit.set() self.should_exit.set()
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(self._shutdown(), loop = self.channel.loop)
self._shutdown(),
loop = self.channel.loop,
)
def _change_reverse_host(self, f): def _change_reverse_host(self, f):
""" """

View File

@ -103,8 +103,6 @@ def run(
server = proxy.server.DummyServer(pconf) server = proxy.server.DummyServer(pconf)
master.server = server master.server = server
master.addons.trigger("configure", opts.keys())
master.addons.trigger("tick")
opts.update_known(**unknown) opts.update_known(**unknown)
if args.options: if args.options:
print(optmanager.dump_defaults(opts)) print(optmanager.dump_defaults(opts))

View File

@ -3,12 +3,48 @@ import pytest
from mitmproxy.addons import keepserving from mitmproxy.addons import keepserving
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy import command
class Dummy:
def __init__(self, val: bool):
self.val = val
def load(self, loader):
loader.add_option("client_replay", bool, self.val, "test")
loader.add_option("server_replay", bool, self.val, "test")
loader.add_option("rfile", bool, self.val, "test")
@command.command("readfile.reading")
def readfile(self) -> bool:
return self.val
@command.command("replay.client.count")
def creplay(self) -> int:
return 1 if self.val else 0
@command.command("replay.server.count")
def sreplay(self) -> int:
return 1 if self.val else 0
class TKS(keepserving.KeepServing):
_is_shutdown = False
def shutdown(self):
self.is_shutdown = True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_keepserving(): async def test_keepserving():
ks = keepserving.KeepServing() ks = TKS()
d = Dummy(True)
with taddons.context(ks) as tctx: with taddons.context(ks) as tctx:
ks.event_processing_complete() tctx.master.addons.add(d)
asyncio.sleep(0.1) ks.running()
assert tctx.master.should_exit.is_set() assert ks.keepgoing()
d.val = False
assert not ks.keepgoing()
await asyncio.sleep(0.3)
assert ks.is_shutdown

View File

@ -51,6 +51,8 @@ class TestReadFile:
async def test_read(self, tmpdir, data, corrupt_data): async def test_read(self, tmpdir, data, corrupt_data):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context(rf) as tctx: with taddons.context(rf) as tctx:
assert not rf.reading()
tf = tmpdir.join("tfile") tf = tmpdir.join("tfile")
with asynctest.patch('mitmproxy.master.Master.load_flow') as mck: with asynctest.patch('mitmproxy.master.Master.load_flow') as mck:

View File

@ -39,16 +39,6 @@ def test_config(tmpdir):
tctx.configure(s, server_replay=[str(tmpdir)]) tctx.configure(s, server_replay=[str(tmpdir)])
def test_tick():
s = serverplayback.ServerPlayback()
with taddons.context(s) as tctx:
s.stop = True
s.final_flow = tflow.tflow()
s.final_flow.live = False
s.tick()
assert tctx.master.has_event("processing_complete")
def test_server_playback(): def test_server_playback():
sp = serverplayback.ServerPlayback() sp = serverplayback.ServerPlayback()
with taddons.context(sp) as tctx: with taddons.context(sp) as tctx:
@ -349,14 +339,6 @@ def test_server_playback_full():
s.request(tf) s.request(tf)
assert not tf.response assert not tf.response
assert not s.stop
s.tick()
assert not s.stop
tf = tflow.tflow()
s.request(tflow.tflow())
assert s.stop
def test_server_playback_kill(): def test_server_playback_kill():
s = serverplayback.ServerPlayback() s = serverplayback.ServerPlayback()

View File

@ -30,6 +30,7 @@ def test_statusbar(monkeypatch):
m.options.update(view_order='url', console_focus_follow=True) m.options.update(view_order='url', console_focus_follow=True)
monkeypatch.setattr(m.addons.get("clientplayback"), "count", lambda: 42) monkeypatch.setattr(m.addons.get("clientplayback"), "count", lambda: 42)
monkeypatch.setattr(m.addons.get("serverplayback"), "count", lambda: 42) monkeypatch.setattr(m.addons.get("serverplayback"), "count", lambda: 42)
monkeypatch.setattr(statusbar.StatusBar, "refresh", lambda x: None)
bar = statusbar.StatusBar(m) # this already causes a redraw bar = statusbar.StatusBar(m) # this already causes a redraw
assert bar.ib._w assert bar.ib._w