Merge pull request #3090 from cortesi/iflight

Redesign keepserving
This commit is contained in:
Aldo Cortesi 2018-05-02 13:45:06 +12:00 committed by GitHub
commit d12186a935
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 99 additions and 62 deletions

View File

@ -1,4 +1,5 @@
import queue import queue
import threading
import typing import typing
from mitmproxy import log from mitmproxy import log
@ -30,12 +31,15 @@ class RequestReplayThread(basethread.BaseThread):
self.options = opts self.options = opts
self.channel = channel self.channel = channel
self.queue = queue self.queue = queue
self.inflight = threading.Event()
super().__init__("RequestReplayThread") super().__init__("RequestReplayThread")
def run(self): def run(self):
while True: while True:
f = self.queue.get() f = self.queue.get()
self.inflight.set()
self.replay(f) self.replay(f)
self.inflight.clear()
def replay(self, f): # pragma: no cover def replay(self, f): # pragma: no cover
f.live = True f.live = True
@ -163,7 +167,8 @@ class ClientPlayback:
""" """
Approximate number of flows queued for replay. Approximate number of flows queued for replay.
""" """
return self.q.qsize() inflight = 1 if self.thread and self.thread.inflight.is_set() else 0
return self.q.qsize() + inflight
@command.command("replay.client.stop") @command.command("replay.client.stop")
def stop_replay(self) -> None: def stop_replay(self) -> None:

View File

@ -1,3 +1,4 @@
import asyncio
from mitmproxy import ctx from mitmproxy import ctx
@ -12,6 +13,28 @@ class KeepServing:
""" """
) )
def event_processing_complete(self): def keepgoing(self) -> bool:
if not ctx.master.options.keepserving: checks = [
"readfile.reading",
"replay.client.count",
"replay.server.count",
]
return any([ctx.master.commands.call(c) for c in checks])
def shutdown(self): # pragma: no cover
ctx.master.shutdown() ctx.master.shutdown()
async def watch(self):
while True:
await asyncio.sleep(0.1)
if not self.keepgoing():
self.shutdown()
def running(self):
opts = [
ctx.options.client_replay,
ctx.options.server_replay,
ctx.options.rfile,
]
if any(opts) and not ctx.options.keepserving:
asyncio.get_event_loop().create_task(self.watch())

View File

@ -7,6 +7,7 @@ from mitmproxy import ctx
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import flowfilter from mitmproxy import flowfilter
from mitmproxy import io from mitmproxy import io
from mitmproxy import command
class ReadFile: class ReadFile:
@ -15,6 +16,7 @@ class ReadFile:
""" """
def __init__(self): def __init__(self):
self.filter = None self.filter = None
self.is_reading = False
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
@ -65,17 +67,23 @@ class ReadFile:
raise exceptions.FlowReadException(str(e)) from e raise exceptions.FlowReadException(str(e)) from e
async def doread(self, rfile): async def doread(self, rfile):
self.is_reading = True
try: try:
await self.load_flows_from_path(ctx.options.rfile) await self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e: except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e raise exceptions.OptionsError(e) from e
finally: finally:
self.is_reading = False
ctx.master.addons.trigger("processing_complete") ctx.master.addons.trigger("processing_complete")
def running(self): def running(self):
if ctx.options.rfile: if ctx.options.rfile:
asyncio.get_event_loop().create_task(self.doread(ctx.options.rfile)) asyncio.get_event_loop().create_task(self.doread(ctx.options.rfile))
@command.command("readfile.reading")
def reading(self) -> bool:
return self.is_reading
class ReadFileStdin(ReadFile): class ReadFileStdin(ReadFile):
"""Support the special case of "-" for reading from stdin""" """Support the special case of "-" for reading from stdin"""

View File

@ -13,8 +13,6 @@ import mitmproxy.types
class ServerPlayback: class ServerPlayback:
def __init__(self): def __init__(self):
self.flowmap = {} self.flowmap = {}
self.stop = False
self.final_flow = None
self.configured = False self.configured = False
def load(self, loader): def load(self, loader):
@ -99,7 +97,8 @@ class ServerPlayback:
self.flowmap = {} self.flowmap = {}
ctx.master.addons.trigger("update", []) ctx.master.addons.trigger("update", [])
def count(self): @command.command("replay.server.count")
def count(self) -> int:
return sum([len(i) for i in self.flowmap.values()]) return sum([len(i) for i in self.flowmap.values()])
def _hash(self, flow): def _hash(self, flow):
@ -174,10 +173,6 @@ class ServerPlayback:
raise exceptions.OptionsError(str(e)) raise exceptions.OptionsError(str(e))
self.load_flows(flows) self.load_flows(flows)
def tick(self):
if self.stop and not self.final_flow.live:
ctx.master.addons.trigger("processing_complete")
def request(self, f): def request(self, f):
if self.flowmap: if self.flowmap:
rflow = self.next_flow(f) rflow = self.next_flow(f)
@ -187,9 +182,6 @@ class ServerPlayback:
if ctx.options.server_replay_refresh: if ctx.options.server_replay_refresh:
response.refresh() response.refresh()
f.response = response f.response = response
if not self.flowmap:
self.final_flow = f
self.stop = True
elif ctx.options.server_replay_kill_extra: elif ctx.options.server_replay_kill_extra:
ctx.log.warn( ctx.log.warn(
"server_playback: killed non-replay request {}".format( "server_playback: killed non-replay request {}".format(

View File

@ -127,10 +127,7 @@ class Master:
""" """
if not self.should_exit.is_set(): if not self.should_exit.is_set():
self.should_exit.set() self.should_exit.set()
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(self._shutdown(), loop = self.channel.loop)
self._shutdown(),
loop = self.channel.loop,
)
def _change_reverse_host(self, f): def _change_reverse_host(self, f):
""" """

View File

@ -17,10 +17,6 @@ class TestAddons(addonmanager.AddonManager):
def trigger(self, event, *args, **kwargs): def trigger(self, event, *args, **kwargs):
if event == "log": if event == "log":
self.master.logs.append(args[0]) self.master.logs.append(args[0])
elif event == "tick" and not args and not kwargs:
pass
else:
self.master.events.append((event, args, kwargs))
super().trigger(event, *args, **kwargs) super().trigger(event, *args, **kwargs)
@ -28,7 +24,6 @@ class RecordingMaster(mitmproxy.master.Master):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.addons = TestAddons(self) self.addons = TestAddons(self)
self.events = []
self.logs = [] self.logs = []
def dump_log(self, outf=sys.stdout): def dump_log(self, outf=sys.stdout):
@ -51,12 +46,6 @@ class RecordingMaster(mitmproxy.master.Master):
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
return False return False
def has_event(self, name):
for i in self.events:
if i[0] == name:
return True
return False
def clear(self): def clear(self):
self.logs = [] self.logs = []

View File

@ -14,7 +14,7 @@ class CommandExecutor:
def __call__(self, cmd): def __call__(self, cmd):
if cmd.strip(): if cmd.strip():
try: try:
ret = self.master.commands.call(cmd) ret = self.master.commands.execute(cmd)
except exceptions.CommandError as v: except exceptions.CommandError as v:
signals.status_message.send(message=str(v)) signals.status_message.send(message=str(v))
else: else:

View File

@ -79,11 +79,11 @@ class FlowListBox(urwid.ListBox, layoutwidget.LayoutWidget):
def keypress(self, size, key): def keypress(self, size, key):
if key == "m_start": if key == "m_start":
self.master.commands.call("view.go 0") self.master.commands.execute("view.go 0")
elif key == "m_end": elif key == "m_end":
self.master.commands.call("view.go -1") self.master.commands.execute("view.go -1")
elif key == "m_select": elif key == "m_select":
self.master.commands.call("console.view.flow @focus") self.master.commands.execute("console.view.flow @focus")
return urwid.ListBox.keypress(self, size, key) return urwid.ListBox.keypress(self, size, key)
def view_changed(self): def view_changed(self):

View File

@ -98,7 +98,7 @@ class FlowDetails(tabs.Tabs):
msg, body = "", [urwid.Text([("error", "[content missing]")])] msg, body = "", [urwid.Text([("error", "[content missing]")])]
return msg, body return msg, body
else: else:
full = self.master.commands.call("view.getval @focus fullcontents false") full = self.master.commands.execute("view.getval @focus fullcontents false")
if full == "true": if full == "true":
limit = sys.maxsize limit = sys.maxsize
else: else:

View File

@ -158,6 +158,7 @@ class ActionBar(urwid.WidgetWrap):
class StatusBar(urwid.WidgetWrap): class StatusBar(urwid.WidgetWrap):
REFRESHTIME = 0.5 # Timed refresh time in seconds
keyctx = "" keyctx = ""
def __init__( def __init__(
@ -173,7 +174,11 @@ class StatusBar(urwid.WidgetWrap):
master.options.changed.connect(self.sig_update) master.options.changed.connect(self.sig_update)
master.view.focus.sig_change.connect(self.sig_update) master.view.focus.sig_change.connect(self.sig_update)
master.view.sig_view_add.connect(self.sig_update) master.view.sig_view_add.connect(self.sig_update)
self.refresh()
def refresh(self):
self.redraw() self.redraw()
signals.call_in.send(seconds=self.REFRESHTIME, callback=self.refresh)
def sig_update(self, sender, flow=None, updated=None): def sig_update(self, sender, flow=None, updated=None):
self.redraw() self.redraw()
@ -184,7 +189,7 @@ class StatusBar(urwid.WidgetWrap):
def get_status(self): def get_status(self):
r = [] r = []
sreplay = self.master.addons.get("serverplayback") sreplay = self.master.commands.call("replay.server.count")
creplay = self.master.commands.call("replay.client.count") creplay = self.master.commands.call("replay.client.count")
if len(self.master.options.setheaders): if len(self.master.options.setheaders):
@ -197,10 +202,10 @@ class StatusBar(urwid.WidgetWrap):
r.append("[") r.append("[")
r.append(("heading_key", "cplayback")) r.append(("heading_key", "cplayback"))
r.append(":%s]" % creplay) r.append(":%s]" % creplay)
if sreplay.count(): if sreplay:
r.append("[") r.append("[")
r.append(("heading_key", "splayback")) r.append(("heading_key", "splayback"))
r.append(":%s]" % sreplay.count()) r.append(":%s]" % sreplay)
if self.master.options.ignore_hosts: if self.master.options.ignore_hosts:
r.append("[") r.append("[")
r.append(("heading_key", "I")) r.append(("heading_key", "I"))

View File

@ -103,8 +103,6 @@ def run(
server = proxy.server.DummyServer(pconf) server = proxy.server.DummyServer(pconf)
master.server = server master.server = server
master.addons.trigger("configure", opts.keys())
master.addons.trigger("tick")
opts.update_known(**unknown) opts.update_known(**unknown)
if args.options: if args.options:
print(optmanager.dump_defaults(opts)) print(optmanager.dump_defaults(opts))

View File

@ -3,12 +3,48 @@ import pytest
from mitmproxy.addons import keepserving from mitmproxy.addons import keepserving
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy import command
class Dummy:
def __init__(self, val: bool):
self.val = val
def load(self, loader):
loader.add_option("client_replay", bool, self.val, "test")
loader.add_option("server_replay", bool, self.val, "test")
loader.add_option("rfile", bool, self.val, "test")
@command.command("readfile.reading")
def readfile(self) -> bool:
return self.val
@command.command("replay.client.count")
def creplay(self) -> int:
return 1 if self.val else 0
@command.command("replay.server.count")
def sreplay(self) -> int:
return 1 if self.val else 0
class TKS(keepserving.KeepServing):
_is_shutdown = False
def shutdown(self):
self.is_shutdown = True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_keepserving(): async def test_keepserving():
ks = keepserving.KeepServing() ks = TKS()
d = Dummy(True)
with taddons.context(ks) as tctx: with taddons.context(ks) as tctx:
ks.event_processing_complete() tctx.master.addons.add(d)
asyncio.sleep(0.1) ks.running()
assert tctx.master.should_exit.is_set() assert ks.keepgoing()
d.val = False
assert not ks.keepgoing()
await asyncio.sleep(0.3)
assert ks.is_shutdown

View File

@ -51,6 +51,8 @@ class TestReadFile:
async def test_read(self, tmpdir, data, corrupt_data): async def test_read(self, tmpdir, data, corrupt_data):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context(rf) as tctx: with taddons.context(rf) as tctx:
assert not rf.reading()
tf = tmpdir.join("tfile") tf = tmpdir.join("tfile")
with asynctest.patch('mitmproxy.master.Master.load_flow') as mck: with asynctest.patch('mitmproxy.master.Master.load_flow') as mck:

View File

@ -39,16 +39,6 @@ def test_config(tmpdir):
tctx.configure(s, server_replay=[str(tmpdir)]) tctx.configure(s, server_replay=[str(tmpdir)])
def test_tick():
s = serverplayback.ServerPlayback()
with taddons.context(s) as tctx:
s.stop = True
s.final_flow = tflow.tflow()
s.final_flow.live = False
s.tick()
assert tctx.master.has_event("processing_complete")
def test_server_playback(): def test_server_playback():
sp = serverplayback.ServerPlayback() sp = serverplayback.ServerPlayback()
with taddons.context(sp) as tctx: with taddons.context(sp) as tctx:
@ -349,14 +339,6 @@ def test_server_playback_full():
s.request(tf) s.request(tf)
assert not tf.response assert not tf.response
assert not s.stop
s.tick()
assert not s.stop
tf = tflow.tflow()
s.request(tflow.tflow())
assert s.stop
def test_server_playback_kill(): def test_server_playback_kill():
s = serverplayback.ServerPlayback() s = serverplayback.ServerPlayback()

View File

@ -10,7 +10,6 @@ from mitmproxy import ctx
async def test_recordingmaster(): async def test_recordingmaster():
with taddons.context() as tctx: with taddons.context() as tctx:
assert not tctx.master.has_log("nonexistent") assert not tctx.master.has_log("nonexistent")
assert not tctx.master.has_event("nonexistent")
ctx.log.error("foo") ctx.log.error("foo")
assert not tctx.master.has_log("foo", level="debug") assert not tctx.master.has_log("foo", level="debug")
assert await tctx.master.await_log("foo", level="error") assert await tctx.master.await_log("foo", level="error")

View File

@ -30,6 +30,7 @@ def test_statusbar(monkeypatch):
m.options.update(view_order='url', console_focus_follow=True) m.options.update(view_order='url', console_focus_follow=True)
monkeypatch.setattr(m.addons.get("clientplayback"), "count", lambda: 42) monkeypatch.setattr(m.addons.get("clientplayback"), "count", lambda: 42)
monkeypatch.setattr(m.addons.get("serverplayback"), "count", lambda: 42) monkeypatch.setattr(m.addons.get("serverplayback"), "count", lambda: 42)
monkeypatch.setattr(statusbar.StatusBar, "refresh", lambda x: None)
bar = statusbar.StatusBar(m) # this already causes a redraw bar = statusbar.StatusBar(m) # this already causes a redraw
assert bar.ib._w assert bar.ib._w