client replay: re-design
Re-design the way client replay works. Before, we would fire up a thread, replay, wait for the thread to complete, get the next flow, and repeat the procedure. Now, we have a single replay thread, started when the addon starts, that pops flows off a thread-safe queue. This is much cleaner, removes the need for a busy tick loop, and sets the scene for optimisations like server connection reuse down the track.
commit 236a2fb6fd
parent 28d53d5a24
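For orientation, here is a minimal, standalone sketch of the pattern the commit message describes: one long-lived worker thread consuming from a thread-safe queue, using only the Python standard library. The names (ReplayWorker, replay_one) are illustrative and are not mitmproxy's API; the actual implementation is in the diff below.

# Illustration only (not mitmproxy code): one long-lived worker thread that
# pops work items off a thread-safe queue, as the commit message describes.
import queue
import threading


class ReplayWorker(threading.Thread):
    daemon = True  # don't keep the interpreter alive on shutdown

    def __init__(self, q: queue.Queue) -> None:
        super().__init__(name="ReplayWorker")
        self.q = q

    def run(self) -> None:
        while True:
            item = self.q.get(block=True, timeout=None)  # blocks until work arrives
            try:
                self.replay_one(item)
            finally:
                self.q.task_done()

    def replay_one(self, item) -> None:
        print("replaying", item)


if __name__ == "__main__":
    q = queue.Queue()
    ReplayWorker(q).start()  # started once, e.g. when the addon starts
    q.put("flow-1")          # producers simply enqueue; no per-flow thread
    q.put("flow-2")          # and no busy tick() polling
    q.join()                 # wait for the queue to drain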
@@ -1,3 +1,6 @@
import queue
import typing

from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
@@ -14,46 +17,47 @@ from mitmproxy import io
from mitmproxy import command
import mitmproxy.types

import typing


class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread"
daemon = True

def __init__(
self,
opts: options.Options,
f: http.HTTPFlow,
channel: controller.Channel,
queue: queue.Queue,
) -> None:
self.options = opts
self.f = f
f.live = True
self.channel = channel
super().__init__(
"RequestReplay (%s)" % f.request.url
)
self.daemon = True
self.queue = queue
super().__init__("RequestReplayThread")

def run(self):
r = self.f.request
while True:
f = self.queue.get(block=True, timeout=None)
self.replay(f)

def replay(self, f):
f.live = True
r = f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
self.f.response = None
f.response = None

# If we have a channel, run script hooks.
if self.channel:
request_reply = self.channel.ask("request", self.f)
request_reply = self.channel.ask("request", f)
if isinstance(request_reply, http.HTTPResponse):
self.f.response = request_reply
f.response = request_reply

if not self.f.response:
if not f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
server = connections.ServerConnection(server_address, (self.options.listen_host, 0))
server = connections.ServerConnection(
server_address, (self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
@@ -65,9 +69,11 @@ class RequestReplayThread(basethread.BaseThread):
body_size_limit=bsl
)
if resp.status_code != 200:
raise exceptions.ReplayException("Upstream server refuses CONNECT request")
raise exceptions.ReplayException(
"Upstream server refuses CONNECT request"
)
server.establish_tls(
sni=self.f.server_conn.sni,
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
@@ -82,7 +88,7 @@ class RequestReplayThread(basethread.BaseThread):
server.connect()
if r.scheme == "https":
server.establish_tls(
sni=self.f.server_conn.sni,
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
@@ -90,104 +96,44 @@ class RequestReplayThread(basethread.BaseThread):
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()

if self.f.server_conn:
self.f.server_conn.close()
self.f.server_conn = server
if f.server_conn:
f.server_conn.close()
f.server_conn = server

self.f.response = http.HTTPResponse.wrap(
http1.read_response(
server.rfile,
r,
body_size_limit=bsl
f.response = http.HTTPResponse.wrap(
http1.read_response(server.rfile, r, body_size_limit=bsl)
)
)
if self.channel:
response_reply = self.channel.ask("response", self.f)
response_reply = self.channel.ask("response", f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
self.f.error = flow.Error(str(e))
if self.channel:
self.channel.ask("error", self.f)
f.error = flow.Error(str(e))
self.channel.ask("error", f)
except exceptions.Kill:
# Kill should only be raised if there's a channel in the
# first place.
self.channel.tell(
"log",
log.LogEntry("Connection killed", "info")
)
self.channel.tell("log", log.LogEntry("Connection killed", "info"))
except Exception as e:
self.channel.tell(
"log",
log.LogEntry(repr(e), "error")
)
self.channel.tell("log", log.LogEntry(repr(e), "error"))
finally:
r.first_line_format = first_line_format_backup
self.f.live = False
f.live = False
if server.connected():
server.finish()


class ClientPlayback:
def __init__(self):
self.flows: typing.List[flow.Flow] = []
self.current_thread = None
self.configured = False

def replay_request(
self,
f: http.HTTPFlow,
block: bool=False
) -> RequestReplayThread:
"""
Replay a HTTP request to receive a new response from the server.

Args:
f: The flow to replay.
block: If True, this function will wait for the replay to finish.
This causes a deadlock if activated in the main thread.

Returns:
The thread object doing the replay.

Raises:
exceptions.ReplayException, if the flow is in a state
where it is ineligible for replay.
"""
self.q: queue.Queue = queue.Queue()
self.thread: RequestReplayThread | None = None

def check(self, f: http.HTTPFlow):
if f.live:
raise exceptions.ReplayException(
"Can't replay live flow."
)
return "Can't replay live flow."
if f.intercepted:
raise exceptions.ReplayException(
"Can't replay intercepted flow."
)
return "Can't replay intercepted flow."
if not f.request:
raise exceptions.ReplayException(
"Can't replay flow with missing request."
)
return "Can't replay flow with missing request."
if f.request.raw_content is None:
raise exceptions.ReplayException(
"Can't replay flow with missing content."
)

f.backup()
f.request.is_replay = True

f.response = None
f.error = None

if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197
f.request.http_version = "HTTP/1.1"
host = f.request.headers.pop(":authority")
f.request.headers.insert(0, "host", host)

rt = RequestReplayThread(ctx.master.options, f, ctx.master.channel)
rt.start() # pragma: no cover
if block:
rt.join()
return rt
return "Can't replay flow with missing content."

def load(self, loader):
loader.add_option(
@@ -195,65 +141,73 @@ class ClientPlayback:
"Replay client requests from a saved file."
)

def count(self) -> int:
if self.current_thread:
current = 1
else:
current = 0
return current + len(self.flows)

@command.command("replay.client.stop")
def stop_replay(self) -> None:
"""
Stop client replay.
"""
self.flows = []
ctx.log.alert("Client replay stopped.")
ctx.master.addons.trigger("update", [])

@command.command("replay.client")
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
"""
Replay requests from flows.
"""
for f in flows:
if f.live:
raise exceptions.CommandError("Can't replay live flow.")
self.flows = list(flows)
ctx.log.alert("Replaying %s flows." % len(self.flows))
ctx.master.addons.trigger("update", [])

@command.command("replay.client.file")
def load_file(self, path: mitmproxy.types.Path) -> None:
try:
flows = io.read_flows_from_paths([path])
except exceptions.FlowReadException as e:
raise exceptions.CommandError(str(e))
ctx.log.alert("Replaying %s flows." % len(self.flows))
self.flows = flows
ctx.master.addons.trigger("update", [])
def running(self):
self.thread = RequestReplayThread(
ctx.options,
ctx.master.channel,
self.q,
)
self.thread.start()

def configure(self, updated):
if not self.configured and ctx.options.client_replay:
self.configured = True
ctx.log.info("Client Replay: {}".format(ctx.options.client_replay))
if "client_replay" in updated and ctx.options.client_replay:
try:
flows = io.read_flows_from_paths(ctx.options.client_replay)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(str(e))
self.start_replay(flows)

def tick(self):
current_is_done = self.current_thread and not self.current_thread.is_alive()
can_start_new = not self.current_thread or current_is_done
will_start_new = can_start_new and self.flows
@command.command("replay.client.count")
def count(self) -> int:
"""
Approximate number of flows queued for replay.
"""
return self.q.qsize()

if current_is_done:
self.current_thread = None
ctx.master.addons.trigger("update", [])
if will_start_new:
f = self.flows.pop(0)
self.current_thread = self.replay_request(f)
ctx.master.addons.trigger("update", [f])
if current_is_done and not will_start_new:
ctx.master.addons.trigger("processing_complete")
@command.command("replay.client.stop")
def stop_replay(self) -> None:
"""
Clear the replay queue.
"""
with self.q.mutex:
lst = list(self.q.queue)
self.q.queue.clear()
ctx.master.addons.trigger("update", lst)
ctx.log.alert("Client replay queue cleared.")

@command.command("replay.client")
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
"""
Add flows to the replay queue, skipping flows that can't be replayed.
"""
lst = []
for f in flows:
err = self.check(f)
if err:
ctx.log.warn(err)
continue

lst.append(f)
# Prepare the flow for replay
f.backup()
f.request.is_replay = True
f.response = None
f.error = None
# https://github.com/mitmproxy/mitmproxy/issues/2197
if f.request.http_version == "HTTP/2.0":
f.request.http_version = "HTTP/1.1"
host = f.request.headers.pop(":authority")
f.request.headers.insert(0, "host", host)
self.q.put(f)
ctx.master.addons.trigger("update", lst)

@command.command("replay.client.file")
def load_file(self, path: mitmproxy.types.Path) -> None:
"""
Load flows from file, and add them to the replay queue.
"""
try:
flows = io.read_flows_from_paths([path])
except exceptions.FlowReadException as e:
raise exceptions.CommandError(str(e))
self.start_replay(flows)
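One detail worth noting in the new stop_replay above: it snapshots and clears the pending queue while holding the queue's internal mutex. As a standalone illustration of that idiom (not mitmproxy code), and assuming nothing relies on q.join()/task_done() bookkeeping for this queue, as in the addon above:

# Illustration only: atomically snapshot and clear a queue.Queue, mirroring
# the stop_replay() idiom in the diff above.
import queue


def drain(q: queue.Queue) -> list:
    # q.mutex guards q.queue (a collections.deque); holding it lets us take a
    # snapshot and clear the backlog atomically. Unfinished-task bookkeeping is
    # untouched, which is fine as long as nothing calls q.join() on this queue.
    with q.mutex:
        pending = list(q.queue)
        q.queue.clear()
    return pending


if __name__ == "__main__":
    q = queue.Queue()
    for i in range(3):
        q.put(i)
    print(drain(q))   # [0, 1, 2]
    print(q.qsize())  # 0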
@@ -112,12 +112,10 @@ class context:
if addon not in self.master.addons:
self.master.addons.register(addon)
with self.options.rollback(kwargs.keys(), reraise=True):
if kwargs:
self.options.update(**kwargs)
self.master.addons.invoke_addon(
addon,
"configure",
kwargs.keys()
)
else:
self.master.addons.invoke_addon(addon, "configure", {})

def script(self, path):
"""
@@ -12,11 +12,11 @@ class TestCheckCA:
async def test_check_ca(self, expired):
msg = 'The mitmproxy certificate authority has expired!'

with taddons.context() as tctx:
a = check_ca.CheckCA()
with taddons.context(a) as tctx:
tctx.master.server = mock.MagicMock()
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(
return_value = expired
)
a = check_ca.CheckCA()
tctx.configure(a)
assert await tctx.master.await_log(msg) == expired
@@ -92,47 +92,13 @@ class TestClientPlayback:
# assert rt.f.request.http_version == "HTTP/1.1"
# assert ":authority" not in rt.f.request.headers

def test_playback(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp) as tctx:
assert cp.count() == 0
f = tflow.tflow(resp=True)
cp.start_replay([f])
assert cp.count() == 1
RP = "mitmproxy.addons.clientplayback.RequestReplayThread"
with mock.patch(RP) as rp:
assert not cp.current_thread
cp.tick()
assert rp.called
assert cp.current_thread

cp.flows = []
cp.current_thread.is_alive.return_value = False
assert cp.count() == 1
cp.tick()
assert cp.count() == 0
assert tctx.master.has_event("update")
assert tctx.master.has_event("processing_complete")

cp.current_thread = MockThread()
cp.tick()
assert cp.current_thread is None

cp.start_replay([f])
cp.stop_replay()
assert not cp.flows

df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True)
with pytest.raises(exceptions.CommandError, match="Can't replay live flow."):
cp.start_replay([df])

def test_load_file(self, tmpdir):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
fpath = str(tmpdir.join("flows"))
tdump(fpath, [tflow.tflow(resp=True)])
cp.load_file(fpath)
assert cp.flows
assert cp.count() == 1
with pytest.raises(exceptions.CommandError):
cp.load_file("/nonexistent")

@@ -141,11 +107,39 @@ class TestClientPlayback:
with taddons.context(cp) as tctx:
path = str(tmpdir.join("flows"))
tdump(path, [tflow.tflow()])
assert cp.count() == 0
tctx.configure(cp, client_replay=[path])
cp.configured = False
assert cp.count() == 1
tctx.configure(cp, client_replay=[])
cp.configured = False
tctx.configure(cp)
cp.configured = False
with pytest.raises(exceptions.OptionsError):
tctx.configure(cp, client_replay=["nonexistent"])

def test_check(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
f = tflow.tflow(resp=True)
f.live = True
assert "live flow" in cp.check(f)

f = tflow.tflow(resp=True)
f.intercepted = True
assert "intercepted flow" in cp.check(f)

f = tflow.tflow(resp=True)
f.request = None
assert "missing request" in cp.check(f)

f = tflow.tflow(resp=True)
f.request.raw_content = None
assert "missing content" in cp.check(f)

def test_playback(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
assert cp.count() == 0
f = tflow.tflow(resp=True)
cp.start_replay([f])
assert cp.count() == 1

cp.stop_replay()
assert cp.count() == 0
@@ -42,7 +42,7 @@ def corrupt_data():
class TestReadFile:
def test_configure(self):
rf = readfile.ReadFile()
with taddons.context() as tctx:
with taddons.context(rf) as tctx:
tctx.configure(rf, readfile_filter="~q")
with pytest.raises(Exception, match="Invalid readfile filter"):
tctx.configure(rf, readfile_filter="~~")
@@ -11,7 +11,7 @@ from mitmproxy.addons import view

def test_configure(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
with pytest.raises(exceptions.OptionsError):
tctx.configure(sa, save_stream_file=str(tmpdir))
with pytest.raises(Exception, match="Invalid filter"):
@@ -32,7 +32,7 @@ def rd(p):

def test_tcp(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)

@@ -45,7 +45,7 @@ def test_tcp(tmpdir):

def test_websocket(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)

@@ -78,7 +78,7 @@ def test_save_command(tmpdir):

def test_simple(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))

tctx.configure(sa, save_stream_file=p)
@@ -92,14 +92,13 @@ class TestScript:

@pytest.mark.asyncio
async def test_simple(self, tdata):
with taddons.context() as tctx:
sc = script.Script(
tdata.path(
"mitmproxy/data/addonscripts/recorder/recorder.py"
),
True,
)
tctx.master.addons.add(sc)
with taddons.context(sc) as tctx:
tctx.configure(sc)
await tctx.master.await_log("recorder running")
rec = tctx.master.addons.get("recorder")
@@ -284,7 +283,7 @@ class TestScriptLoader:
rec = tdata.path("mitmproxy/data/addonscripts/recorder")
sc = script.ScriptLoader()
sc.is_running = True
with taddons.context() as tctx:
with taddons.context(sc) as tctx:
tctx.configure(
sc,
scripts = [
@@ -155,7 +155,7 @@ def test_create():

def test_orders():
v = view.View()
with taddons.context():
with taddons.context(v):
assert v.order_options()


@@ -303,7 +303,7 @@ def test_setgetval():

def test_order():
v = view.View()
with taddons.context() as tctx:
with taddons.context(v) as tctx:
v.request(tft(method="get", start=1))
v.request(tft(method="put", start=2))
v.request(tft(method="get", start=3))
@@ -434,7 +434,7 @@ def test_signals():

def test_focus_follow():
v = view.View()
with taddons.context() as tctx:
with taddons.context(v) as tctx:
console_addon = consoleaddons.ConsoleAddon(tctx.master)
tctx.configure(console_addon)
tctx.configure(v, console_focus_follow=True, view_filter="~m get")
@@ -553,7 +553,7 @@ def test_settings():

def test_configure():
v = view.View()
with taddons.context() as tctx:
with taddons.context(v) as tctx:
tctx.configure(v, view_filter="~q")
with pytest.raises(Exception, match="Invalid interception filter"):
tctx.configure(v, view_filter="~~")