diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index 59459917a..11d2453ba 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -1,18 +1,140 @@ +import queue +import typing + +from mitmproxy import log +from mitmproxy import controller from mitmproxy import exceptions +from mitmproxy import http +from mitmproxy import flow +from mitmproxy import options +from mitmproxy import connections +from mitmproxy.net import server_spec, tls +from mitmproxy.net.http import http1 +from mitmproxy.coretypes import basethread +from mitmproxy.utils import human from mitmproxy import ctx from mitmproxy import io -from mitmproxy import flow from mitmproxy import command import mitmproxy.types -import typing + +class RequestReplayThread(basethread.BaseThread): + daemon = True + + def __init__( + self, + opts: options.Options, + channel: controller.Channel, + queue: queue.Queue, + ) -> None: + self.options = opts + self.channel = channel + self.queue = queue + super().__init__("RequestReplayThread") + + def run(self): + while True: + f = self.queue.get() + self.replay(f) + + def replay(self, f): # pragma: no cover + f.live = True + r = f.request + bsl = human.parse_size(self.options.body_size_limit) + first_line_format_backup = r.first_line_format + server = None + try: + f.response = None + + # If we have a channel, run script hooks. + request_reply = self.channel.ask("request", f) + if isinstance(request_reply, http.HTTPResponse): + f.response = request_reply + + if not f.response: + # In all modes, we directly connect to the server displayed + if self.options.mode.startswith("upstream:"): + server_address = server_spec.parse_with_mode(self.options.mode)[1].address + server = connections.ServerConnection( + server_address, (self.options.listen_host, 0) + ) + server.connect() + if r.scheme == "https": + connect_request = http.make_connect_request((r.data.host, r.port)) + server.wfile.write(http1.assemble_request(connect_request)) + server.wfile.flush() + resp = http1.read_response( + server.rfile, + connect_request, + body_size_limit=bsl + ) + if resp.status_code != 200: + raise exceptions.ReplayException( + "Upstream server refuses CONNECT request" + ) + server.establish_tls( + sni=f.server_conn.sni, + **tls.client_arguments_from_options(self.options) + ) + r.first_line_format = "relative" + else: + r.first_line_format = "absolute" + else: + server_address = (r.host, r.port) + server = connections.ServerConnection( + server_address, + (self.options.listen_host, 0) + ) + server.connect() + if r.scheme == "https": + server.establish_tls( + sni=f.server_conn.sni, + **tls.client_arguments_from_options(self.options) + ) + r.first_line_format = "relative" + + server.wfile.write(http1.assemble_request(r)) + server.wfile.flush() + + if f.server_conn: + f.server_conn.close() + f.server_conn = server + + f.response = http.HTTPResponse.wrap( + http1.read_response(server.rfile, r, body_size_limit=bsl) + ) + response_reply = self.channel.ask("response", f) + if response_reply == exceptions.Kill: + raise exceptions.Kill() + except (exceptions.ReplayException, exceptions.NetlibException) as e: + f.error = flow.Error(str(e)) + self.channel.ask("error", f) + except exceptions.Kill: + self.channel.tell("log", log.LogEntry("Connection killed", "info")) + except Exception as e: + self.channel.tell("log", log.LogEntry(repr(e), "error")) + finally: + r.first_line_format = first_line_format_backup + f.live = False + if server.connected(): + server.finish() + server.close() class ClientPlayback: def __init__(self): - self.flows: typing.List[flow.Flow] = [] - self.current_thread = None - self.configured = False + self.q = queue.Queue() + self.thread: RequestReplayThread = None + + def check(self, f: http.HTTPFlow): + if f.live: + return "Can't replay live flow." + if f.intercepted: + return "Can't replay intercepted flow." + if not f.request: + return "Can't replay flow with missing request." + if f.request.raw_content is None: + return "Can't replay flow with missing content." def load(self, loader): loader.add_option( @@ -20,65 +142,77 @@ class ClientPlayback: "Replay client requests from a saved file." ) - def count(self) -> int: - if self.current_thread: - current = 1 - else: - current = 0 - return current + len(self.flows) - - @command.command("replay.client.stop") - def stop_replay(self) -> None: - """ - Stop client replay. - """ - self.flows = [] - ctx.log.alert("Client replay stopped.") - ctx.master.addons.trigger("update", []) - - @command.command("replay.client") - def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None: - """ - Replay requests from flows. - """ - for f in flows: - if f.live: - raise exceptions.CommandError("Can't replay live flow.") - self.flows = list(flows) - ctx.log.alert("Replaying %s flows." % len(self.flows)) - ctx.master.addons.trigger("update", []) - - @command.command("replay.client.file") - def load_file(self, path: mitmproxy.types.Path) -> None: - try: - flows = io.read_flows_from_paths([path]) - except exceptions.FlowReadException as e: - raise exceptions.CommandError(str(e)) - ctx.log.alert("Replaying %s flows." % len(self.flows)) - self.flows = flows - ctx.master.addons.trigger("update", []) + def running(self): + self.thread = RequestReplayThread( + ctx.options, + ctx.master.channel, + self.q, + ) + self.thread.start() def configure(self, updated): - if not self.configured and ctx.options.client_replay: - self.configured = True - ctx.log.info("Client Replay: {}".format(ctx.options.client_replay)) + if "client_replay" in updated and ctx.options.client_replay: try: flows = io.read_flows_from_paths(ctx.options.client_replay) except exceptions.FlowReadException as e: raise exceptions.OptionsError(str(e)) self.start_replay(flows) - def tick(self): - current_is_done = self.current_thread and not self.current_thread.is_alive() - can_start_new = not self.current_thread or current_is_done - will_start_new = can_start_new and self.flows + @command.command("replay.client.count") + def count(self) -> int: + """ + Approximate number of flows queued for replay. + """ + return self.q.qsize() - if current_is_done: - self.current_thread = None - ctx.master.addons.trigger("update", []) - if will_start_new: - f = self.flows.pop(0) - self.current_thread = ctx.master.replay_request(f) - ctx.master.addons.trigger("update", [f]) - if current_is_done and not will_start_new: - ctx.master.addons.trigger("processing_complete") + @command.command("replay.client.stop") + def stop_replay(self) -> None: + """ + Clear the replay queue. + """ + with self.q.mutex: + lst = list(self.q.queue) + self.q.queue.clear() + for f in lst: + f.revert() + ctx.master.addons.trigger("update", lst) + ctx.log.alert("Client replay queue cleared.") + + @command.command("replay.client") + def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None: + """ + Add flows to the replay queue, skipping flows that can't be replayed. + """ + lst = [] + for f in flows: + hf = typing.cast(http.HTTPFlow, f) + + err = self.check(hf) + if err: + ctx.log.warn(err) + continue + + lst.append(hf) + # Prepare the flow for replay + hf.backup() + hf.request.is_replay = True + hf.response = None + hf.error = None + # https://github.com/mitmproxy/mitmproxy/issues/2197 + if hf.request.http_version == "HTTP/2.0": + hf.request.http_version = "HTTP/1.1" + host = hf.request.headers.pop(":authority") + hf.request.headers.insert(0, "host", host) + self.q.put(hf) + ctx.master.addons.trigger("update", lst) + + @command.command("replay.client.file") + def load_file(self, path: mitmproxy.types.Path) -> None: + """ + Load flows from file, and add them to the replay queue. + """ + try: + flows = io.read_flows_from_paths([path]) + except exceptions.FlowReadException as e: + raise exceptions.CommandError(str(e)) + self.start_replay(flows) diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 77494100f..8f0755254 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -204,7 +204,15 @@ class CommandManager(mitmproxy.types._CommandBase): return parse, remhelp - def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: + def call(self, path: str, *args: typing.Sequence[typing.Any]) -> typing.Any: + """ + Call a command with native arguments. May raise CommandError. + """ + if path not in self.commands: + raise exceptions.CommandError("Unknown command: %s" % path) + return self.commands[path].func(*args) + + def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any: """ Call a command using a list of string arguments. May raise CommandError. """ @@ -212,14 +220,14 @@ class CommandManager(mitmproxy.types._CommandBase): raise exceptions.CommandError("Unknown command: %s" % path) return self.commands[path].call(args) - def call(self, cmdstr: str): + def execute(self, cmdstr: str): """ - Call a command using a string. May raise CommandError. + Execute a command string. May raise CommandError. """ parts = list(lexer(cmdstr)) if not len(parts) >= 1: raise exceptions.CommandError("Invalid command: %s" % cmdstr) - return self.call_args(parts[0], parts[1:]) + return self.call_strings(parts[0], parts[1:]) def dump(self, out=sys.stdout) -> None: cmds = list(self.commands.values()) diff --git a/mitmproxy/master.py b/mitmproxy/master.py index 8eb016008..7f81d1857 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -8,13 +8,11 @@ from mitmproxy import addonmanager from mitmproxy import options from mitmproxy import controller from mitmproxy import eventsequence -from mitmproxy import exceptions from mitmproxy import command from mitmproxy import http from mitmproxy import websocket from mitmproxy import log from mitmproxy.net import server_spec -from mitmproxy.proxy.protocol import http_replay from mitmproxy.coretypes import basethread from . import ctx as mitmproxy_ctx @@ -164,58 +162,3 @@ class Master: f.reply = controller.DummyReply() for e, o in eventsequence.iterate(f): await self.addons.handle_lifecycle(e, o) - - def replay_request( - self, - f: http.HTTPFlow, - block: bool=False - ) -> http_replay.RequestReplayThread: - """ - Replay a HTTP request to receive a new response from the server. - - Args: - f: The flow to replay. - block: If True, this function will wait for the replay to finish. - This causes a deadlock if activated in the main thread. - - Returns: - The thread object doing the replay. - - Raises: - exceptions.ReplayException, if the flow is in a state - where it is ineligible for replay. - """ - - if f.live: - raise exceptions.ReplayException( - "Can't replay live flow." - ) - if f.intercepted: - raise exceptions.ReplayException( - "Can't replay intercepted flow." - ) - if not f.request: - raise exceptions.ReplayException( - "Can't replay flow with missing request." - ) - if f.request.raw_content is None: - raise exceptions.ReplayException( - "Can't replay flow with missing content." - ) - - f.backup() - f.request.is_replay = True - - f.response = None - f.error = None - - if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197 - f.request.http_version = "HTTP/1.1" - host = f.request.headers.pop(":authority") - f.request.headers.insert(0, "host", host) - - rt = http_replay.RequestReplayThread(self.options, f, self.channel) - rt.start() # pragma: no cover - if block: - rt.join() - return rt diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py index 220162910..18429daa6 100644 --- a/mitmproxy/net/tcp.py +++ b/mitmproxy/net/tcp.py @@ -372,12 +372,11 @@ class TCPClient(_Connection): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # it tries to renegotiate... - if not self.connection: - return - elif isinstance(self.connection, SSL.Connection): - close_socket(self.connection._socket) - else: - close_socket(self.connection) + if self.connection: + if isinstance(self.connection, SSL.Connection): + close_socket(self.connection._socket) + else: + close_socket(self.connection) def convert_to_tls(self, sni=None, alpn_protos=None, **sslctx_kwargs): context = tls.create_client_context( diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py deleted file mode 100644 index b2cca2b11..000000000 --- a/mitmproxy/proxy/protocol/http_replay.py +++ /dev/null @@ -1,125 +0,0 @@ -from mitmproxy import log -from mitmproxy import controller -from mitmproxy import exceptions -from mitmproxy import http -from mitmproxy import flow -from mitmproxy import options -from mitmproxy import connections -from mitmproxy.net import server_spec, tls -from mitmproxy.net.http import http1 -from mitmproxy.coretypes import basethread -from mitmproxy.utils import human - - -# TODO: Doesn't really belong into mitmproxy.proxy.protocol... - - -class RequestReplayThread(basethread.BaseThread): - name = "RequestReplayThread" - - def __init__( - self, - opts: options.Options, - f: http.HTTPFlow, - channel: controller.Channel, - ) -> None: - self.options = opts - self.f = f - f.live = True - self.channel = channel - super().__init__( - "RequestReplay (%s)" % f.request.url - ) - self.daemon = True - - def run(self): - r = self.f.request - bsl = human.parse_size(self.options.body_size_limit) - first_line_format_backup = r.first_line_format - server = None - try: - self.f.response = None - - # If we have a channel, run script hooks. - if self.channel: - request_reply = self.channel.ask("request", self.f) - if isinstance(request_reply, http.HTTPResponse): - self.f.response = request_reply - - if not self.f.response: - # In all modes, we directly connect to the server displayed - if self.options.mode.startswith("upstream:"): - server_address = server_spec.parse_with_mode(self.options.mode)[1].address - server = connections.ServerConnection(server_address, (self.options.listen_host, 0)) - server.connect() - if r.scheme == "https": - connect_request = http.make_connect_request((r.data.host, r.port)) - server.wfile.write(http1.assemble_request(connect_request)) - server.wfile.flush() - resp = http1.read_response( - server.rfile, - connect_request, - body_size_limit=bsl - ) - if resp.status_code != 200: - raise exceptions.ReplayException("Upstream server refuses CONNECT request") - server.establish_tls( - sni=self.f.server_conn.sni, - **tls.client_arguments_from_options(self.options) - ) - r.first_line_format = "relative" - else: - r.first_line_format = "absolute" - else: - server_address = (r.host, r.port) - server = connections.ServerConnection( - server_address, - (self.options.listen_host, 0) - ) - server.connect() - if r.scheme == "https": - server.establish_tls( - sni=self.f.server_conn.sni, - **tls.client_arguments_from_options(self.options) - ) - r.first_line_format = "relative" - - server.wfile.write(http1.assemble_request(r)) - server.wfile.flush() - - if self.f.server_conn: - self.f.server_conn.close() - self.f.server_conn = server - - self.f.response = http.HTTPResponse.wrap( - http1.read_response( - server.rfile, - r, - body_size_limit=bsl - ) - ) - if self.channel: - response_reply = self.channel.ask("response", self.f) - if response_reply == exceptions.Kill: - raise exceptions.Kill() - except (exceptions.ReplayException, exceptions.NetlibException) as e: - self.f.error = flow.Error(str(e)) - if self.channel: - self.channel.ask("error", self.f) - except exceptions.Kill: - # Kill should only be raised if there's a channel in the - # first place. - self.channel.tell( - "log", - log.LogEntry("Connection killed", "info") - ) - except Exception as e: - self.channel.tell( - "log", - log.LogEntry(repr(e), "error") - ) - finally: - r.first_line_format = first_line_format_backup - self.f.live = False - if server.connected(): - server.finish() diff --git a/mitmproxy/test/taddons.py b/mitmproxy/test/taddons.py index 0505f9f78..67c15f753 100644 --- a/mitmproxy/test/taddons.py +++ b/mitmproxy/test/taddons.py @@ -112,12 +112,10 @@ class context: if addon not in self.master.addons: self.master.addons.register(addon) with self.options.rollback(kwargs.keys(), reraise=True): - self.options.update(**kwargs) - self.master.addons.invoke_addon( - addon, - "configure", - kwargs.keys() - ) + if kwargs: + self.options.update(**kwargs) + else: + self.master.addons.invoke_addon(addon, "configure", {}) def script(self, path): """ diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 2b9ff334c..81f568ed0 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -258,7 +258,7 @@ class ConsoleAddon: command, then invoke another command with all occurrences of {choice} replaced by the choice the user made. """ - choices = ctx.master.commands.call_args(choicecmd, []) + choices = ctx.master.commands.call_strings(choicecmd, []) def callback(opt): # We're now outside of the call context... @@ -514,7 +514,7 @@ class ConsoleAddon: raise exceptions.CommandError("Invalid flowview mode.") try: - self.master.commands.call_args( + self.master.commands.call_strings( "view.setval", ["@focus", "flowview_mode_%s" % idx, mode] ) @@ -537,7 +537,7 @@ class ConsoleAddon: if not fv: raise exceptions.CommandError("Not viewing a flow.") idx = fv.body.tab_offset - return self.master.commands.call_args( + return self.master.commands.call_strings( "view.getval", [ "@focus", diff --git a/mitmproxy/tools/console/statusbar.py b/mitmproxy/tools/console/statusbar.py index fa987e94d..1e1c0b927 100644 --- a/mitmproxy/tools/console/statusbar.py +++ b/mitmproxy/tools/console/statusbar.py @@ -167,6 +167,7 @@ class StatusBar(urwid.WidgetWrap): self.ib = urwid.WidgetWrap(urwid.Text("")) self.ab = ActionBar(self.master) super().__init__(urwid.Pile([self.ib, self.ab])) + signals.flow_change.connect(self.sig_update) signals.update_settings.connect(self.sig_update) signals.flowlist_change.connect(self.sig_update) master.options.changed.connect(self.sig_update) @@ -184,7 +185,7 @@ class StatusBar(urwid.WidgetWrap): r = [] sreplay = self.master.addons.get("serverplayback") - creplay = self.master.addons.get("clientplayback") + creplay = self.master.commands.call("replay.client.count") if len(self.master.options.setheaders): r.append("[") @@ -192,10 +193,10 @@ class StatusBar(urwid.WidgetWrap): r.append("eaders]") if len(self.master.options.replacements): r.append("[%d replacements]" % len(self.master.options.replacements)) - if creplay.count(): + if creplay: r.append("[") r.append(("heading_key", "cplayback")) - r.append(":%s]" % creplay.count()) + r.append(":%s]" % creplay) if sreplay.count(): r.append("[") r.append(("heading_key", "splayback")) diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 184778b08..ae2394eb0 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -344,7 +344,7 @@ class ReplayFlow(RequestHandler): self.view.update([self.flow]) try: - self.master.replay_request(self.flow) + self.master.commands.call("replay.client", [self.flow]) except exceptions.ReplayException as e: raise APIError(400, str(e)) diff --git a/mitmproxy/types.py b/mitmproxy/types.py index 23320c121..283e7e2ec 100644 --- a/mitmproxy/types.py +++ b/mitmproxy/types.py @@ -47,10 +47,10 @@ class Choice: class _CommandBase: commands: typing.MutableMapping[str, typing.Any] = {} - def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: + def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any: raise NotImplementedError - def call(self, cmd: str) -> typing.Any: + def execute(self, cmd: str) -> typing.Any: raise NotImplementedError @@ -337,7 +337,7 @@ class _FlowType(_BaseFlowType): def parse(self, manager: _CommandBase, t: type, s: str) -> flow.Flow: try: - flows = manager.call_args("view.resolve", [s]) + flows = manager.call_strings("view.resolve", [s]) except exceptions.CommandError as e: raise exceptions.TypeError from e if len(flows) != 1: @@ -356,7 +356,7 @@ class _FlowsType(_BaseFlowType): def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[flow.Flow]: try: - return manager.call_args("view.resolve", [s]) + return manager.call_strings("view.resolve", [s]) except exceptions.CommandError as e: raise exceptions.TypeError from e @@ -401,17 +401,17 @@ class _ChoiceType(_BaseType): display = "choice" def completion(self, manager: _CommandBase, t: Choice, s: str) -> typing.Sequence[str]: - return manager.call(t.options_command) + return manager.execute(t.options_command) def parse(self, manager: _CommandBase, t: Choice, s: str) -> str: - opts = manager.call(t.options_command) + opts = manager.execute(t.options_command) if s not in opts: raise exceptions.TypeError("Invalid choice.") return s def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: try: - opts = manager.call(typ.options_command) + opts = manager.execute(typ.options_command) except exceptions.CommandError: return False return val in opts diff --git a/test/mitmproxy/addons/test_check_ca.py b/test/mitmproxy/addons/test_check_ca.py index 5e820b6df..27e6f7e68 100644 --- a/test/mitmproxy/addons/test_check_ca.py +++ b/test/mitmproxy/addons/test_check_ca.py @@ -12,11 +12,11 @@ class TestCheckCA: async def test_check_ca(self, expired): msg = 'The mitmproxy certificate authority has expired!' - with taddons.context() as tctx: + a = check_ca.CheckCA() + with taddons.context(a) as tctx: tctx.master.server = mock.MagicMock() tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock( return_value = expired ) - a = check_ca.CheckCA() tctx.configure(a) assert await tctx.master.await_log(msg) == expired diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index f172af83e..1b385e237 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -1,13 +1,16 @@ +import time import pytest -from unittest import mock -from mitmproxy.test import tflow +from mitmproxy.test import tflow, tutils from mitmproxy import io from mitmproxy import exceptions +from mitmproxy.net import http as net_http from mitmproxy.addons import clientplayback from mitmproxy.test import taddons +from .. import tservers + def tdump(path, flows): with open(path, "wb") as f: @@ -21,48 +24,87 @@ class MockThread(): return False +class TBase(tservers.HTTPProxyTest): + @staticmethod + def wait_response(flow): + """ + Race condition: We don't want to replay the flow while it is still live. + """ + s = time.time() + while True: + if flow.response or flow.error: + break + time.sleep(0.001) + if time.time() - s > 5: + raise RuntimeError("Flow is live for too long.") + + @staticmethod + def reset(f): + f.live = False + f.repsonse = False + f.error = False + + def addons(self): + return [clientplayback.ClientPlayback()] + + def test_replay(self): + cr = self.master.addons.get("clientplayback") + + assert self.pathod("304").status_code == 304 + assert len(self.master.state.flows) == 1 + l = self.master.state.flows[-1] + assert l.response.status_code == 304 + l.request.path = "/p/305" + cr.start_replay([l]) + self.wait_response(l) + assert l.response.status_code == 305 + + # Disconnect error + cr.stop_replay() + self.reset(l) + l.request.path = "/p/305:d0" + cr.start_replay([l]) + self.wait_response(l) + if isinstance(self, tservers.HTTPUpstreamProxyTest): + assert l.response.status_code == 502 + else: + assert l.error + + # # Port error + cr.stop_replay() + self.reset(l) + l.request.port = 1 + # In upstream mode, we get a 502 response from the upstream proxy server. + # In upstream mode with ssl, the replay will fail as we cannot establish + # SSL with the upstream proxy. + cr.start_replay([l]) + self.wait_response(l) + if isinstance(self, tservers.HTTPUpstreamProxyTest): + assert l.response.status_code == 502 + else: + assert l.error + + +class TestHTTPProxy(TBase, tservers.HTTPProxyTest): + pass + + +class TestHTTPSProxy(TBase, tservers.HTTPProxyTest): + ssl = True + + +class TestUpstreamProxy(TBase, tservers.HTTPUpstreamProxyTest): + pass + + class TestClientPlayback: - def test_playback(self): - cp = clientplayback.ClientPlayback() - with taddons.context(cp) as tctx: - assert cp.count() == 0 - f = tflow.tflow(resp=True) - cp.start_replay([f]) - assert cp.count() == 1 - RP = "mitmproxy.proxy.protocol.http_replay.RequestReplayThread" - with mock.patch(RP) as rp: - assert not cp.current_thread - cp.tick() - assert rp.called - assert cp.current_thread - - cp.flows = [] - cp.current_thread.is_alive.return_value = False - assert cp.count() == 1 - cp.tick() - assert cp.count() == 0 - assert tctx.master.has_event("update") - assert tctx.master.has_event("processing_complete") - - cp.current_thread = MockThread() - cp.tick() - assert cp.current_thread is None - - cp.start_replay([f]) - cp.stop_replay() - assert not cp.flows - - df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True) - with pytest.raises(exceptions.CommandError, match="Can't replay live flow."): - cp.start_replay([df]) - def test_load_file(self, tmpdir): cp = clientplayback.ClientPlayback() with taddons.context(cp): fpath = str(tmpdir.join("flows")) tdump(fpath, [tflow.tflow(resp=True)]) cp.load_file(fpath) - assert cp.flows + assert cp.count() == 1 with pytest.raises(exceptions.CommandError): cp.load_file("/nonexistent") @@ -71,11 +113,63 @@ class TestClientPlayback: with taddons.context(cp) as tctx: path = str(tmpdir.join("flows")) tdump(path, [tflow.tflow()]) + assert cp.count() == 0 tctx.configure(cp, client_replay=[path]) - cp.configured = False + assert cp.count() == 1 tctx.configure(cp, client_replay=[]) - cp.configured = False - tctx.configure(cp) - cp.configured = False with pytest.raises(exceptions.OptionsError): tctx.configure(cp, client_replay=["nonexistent"]) + + def test_check(self): + cp = clientplayback.ClientPlayback() + with taddons.context(cp): + f = tflow.tflow(resp=True) + f.live = True + assert "live flow" in cp.check(f) + + f = tflow.tflow(resp=True) + f.intercepted = True + assert "intercepted flow" in cp.check(f) + + f = tflow.tflow(resp=True) + f.request = None + assert "missing request" in cp.check(f) + + f = tflow.tflow(resp=True) + f.request.raw_content = None + assert "missing content" in cp.check(f) + + @pytest.mark.asyncio + async def test_playback(self): + cp = clientplayback.ClientPlayback() + with taddons.context(cp) as ctx: + assert cp.count() == 0 + f = tflow.tflow(resp=True) + cp.start_replay([f]) + assert cp.count() == 1 + + cp.stop_replay() + assert cp.count() == 0 + + f.live = True + cp.start_replay([f]) + assert cp.count() == 0 + await ctx.master.await_log("live") + + def test_http2(self): + cp = clientplayback.ClientPlayback() + with taddons.context(cp): + req = tutils.treq( + headers = net_http.Headers( + ( + (b":authority", b"foo"), + (b"header", b"qvalue"), + (b"content-length", b"7") + ) + ) + ) + f = tflow.tflow(req=req) + f.request.http_version = "HTTP/2.0" + cp.start_replay([f]) + assert f.request.http_version == "HTTP/1.1" + assert ":authority" not in f.request.headers diff --git a/test/mitmproxy/addons/test_readfile.py b/test/mitmproxy/addons/test_readfile.py index f7e0c5c53..d22382a84 100644 --- a/test/mitmproxy/addons/test_readfile.py +++ b/test/mitmproxy/addons/test_readfile.py @@ -42,7 +42,7 @@ def corrupt_data(): class TestReadFile: def test_configure(self): rf = readfile.ReadFile() - with taddons.context() as tctx: + with taddons.context(rf) as tctx: tctx.configure(rf, readfile_filter="~q") with pytest.raises(Exception, match="Invalid readfile filter"): tctx.configure(rf, readfile_filter="~~") diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index 4486ff783..4aa1f6488 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -11,7 +11,7 @@ from mitmproxy.addons import view def test_configure(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: with pytest.raises(exceptions.OptionsError): tctx.configure(sa, save_stream_file=str(tmpdir)) with pytest.raises(Exception, match="Invalid filter"): @@ -32,7 +32,7 @@ def rd(p): def test_tcp(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: p = str(tmpdir.join("foo")) tctx.configure(sa, save_stream_file=p) @@ -45,7 +45,7 @@ def test_tcp(tmpdir): def test_websocket(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: p = str(tmpdir.join("foo")) tctx.configure(sa, save_stream_file=p) @@ -73,12 +73,12 @@ def test_save_command(tmpdir): v = view.View() tctx.master.addons.add(v) tctx.master.addons.add(sa) - tctx.master.commands.call_args("save.file", ["@shown", p]) + tctx.master.commands.call_strings("save.file", ["@shown", p]) def test_simple(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: p = str(tmpdir.join("foo")) tctx.configure(sa, save_stream_file=p) diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py index c358f0197..916374892 100644 --- a/test/mitmproxy/addons/test_script.py +++ b/test/mitmproxy/addons/test_script.py @@ -92,14 +92,13 @@ class TestScript: @pytest.mark.asyncio async def test_simple(self, tdata): - with taddons.context() as tctx: - sc = script.Script( - tdata.path( - "mitmproxy/data/addonscripts/recorder/recorder.py" - ), - True, - ) - tctx.master.addons.add(sc) + sc = script.Script( + tdata.path( + "mitmproxy/data/addonscripts/recorder/recorder.py" + ), + True, + ) + with taddons.context(sc) as tctx: tctx.configure(sc) await tctx.master.await_log("recorder running") rec = tctx.master.addons.get("recorder") @@ -284,7 +283,7 @@ class TestScriptLoader: rec = tdata.path("mitmproxy/data/addonscripts/recorder") sc = script.ScriptLoader() sc.is_running = True - with taddons.context() as tctx: + with taddons.context(sc) as tctx: tctx.configure( sc, scripts = [ diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 62a6aeb0d..bd724950b 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -155,7 +155,7 @@ def test_create(): def test_orders(): v = view.View() - with taddons.context(): + with taddons.context(v): assert v.order_options() @@ -303,7 +303,7 @@ def test_setgetval(): def test_order(): v = view.View() - with taddons.context() as tctx: + with taddons.context(v) as tctx: v.request(tft(method="get", start=1)) v.request(tft(method="put", start=2)) v.request(tft(method="get", start=3)) @@ -434,7 +434,7 @@ def test_signals(): def test_focus_follow(): v = view.View() - with taddons.context() as tctx: + with taddons.context(v) as tctx: console_addon = consoleaddons.ConsoleAddon(tctx.master) tctx.configure(console_addon) tctx.configure(v, console_focus_follow=True, view_filter="~m get") @@ -553,7 +553,7 @@ def test_settings(): def test_configure(): v = view.View() - with taddons.context() as tctx: + with taddons.context(v) as tctx: tctx.configure(v, view_filter="~q") with pytest.raises(Exception, match="Invalid interception filter"): tctx.configure(v, view_filter="~~") diff --git a/test/mitmproxy/proxy/protocol/test_http_replay.py b/test/mitmproxy/proxy/protocol/test_http_replay.py deleted file mode 100644 index 777ab4dd1..000000000 --- a/test/mitmproxy/proxy/protocol/test_http_replay.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: write tests diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index 936414ab8..914f9184b 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -31,48 +31,6 @@ class CommonMixin: def test_large(self): assert len(self.pathod("200:b@50k").content) == 1024 * 50 - @staticmethod - def wait_until_not_live(flow): - """ - Race condition: We don't want to replay the flow while it is still live. - """ - s = time.time() - while flow.live: - time.sleep(0.001) - if time.time() - s > 5: - raise RuntimeError("Flow is live for too long.") - - def test_replay(self): - assert self.pathod("304").status_code == 304 - assert len(self.master.state.flows) == 1 - l = self.master.state.flows[-1] - assert l.response.status_code == 304 - l.request.path = "/p/305" - self.wait_until_not_live(l) - rt = self.master.replay_request(l, block=True) - assert l.response.status_code == 305 - - # Disconnect error - l.request.path = "/p/305:d0" - rt = self.master.replay_request(l, block=True) - assert rt - if isinstance(self, tservers.HTTPUpstreamProxyTest): - assert l.response.status_code == 502 - else: - assert l.error - - # Port error - l.request.port = 1 - # In upstream mode, we get a 502 response from the upstream proxy server. - # In upstream mode with ssl, the replay will fail as we cannot establish - # SSL with the upstream proxy. - rt = self.master.replay_request(l, block=True) - assert rt - if isinstance(self, tservers.HTTPUpstreamProxyTest): - assert l.response.status_code == 502 - else: - assert l.error - def test_http(self): f = self.pathod("304") assert f.status_code == 304 diff --git a/test/mitmproxy/test_addonmanager.py b/test/mitmproxy/test_addonmanager.py index 796ae1bdd..1ef1521da 100644 --- a/test/mitmproxy/test_addonmanager.py +++ b/test/mitmproxy/test_addonmanager.py @@ -47,7 +47,7 @@ class AOption: def test_command(): with taddons.context() as tctx: tctx.master.addons.add(TAddon("test")) - assert tctx.master.commands.call("test.command") == "here" + assert tctx.master.commands.execute("test.command") == "here" def test_halt(): diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index 3d0a43f88..ea1017e7c 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -242,16 +242,19 @@ def test_simple(): a = TAddon() c.add("one.two", a.cmd1) assert c.commands["one.two"].help == "cmd1 help" - assert(c.call("one.two foo") == "ret foo") + assert(c.execute("one.two foo") == "ret foo") + assert(c.call("one.two", "foo") == "ret foo") + with pytest.raises(exceptions.CommandError, match="Unknown"): + c.execute("nonexistent") + with pytest.raises(exceptions.CommandError, match="Invalid"): + c.execute("") + with pytest.raises(exceptions.CommandError, match="argument mismatch"): + c.execute("one.two too many args") with pytest.raises(exceptions.CommandError, match="Unknown"): c.call("nonexistent") - with pytest.raises(exceptions.CommandError, match="Invalid"): - c.call("") - with pytest.raises(exceptions.CommandError, match="argument mismatch"): - c.call("one.two too many args") c.add("empty", a.empty) - c.call("empty") + c.execute("empty") fp = io.StringIO() c.dump(fp) @@ -340,13 +343,13 @@ def test_decorator(): a = TDec() c.collect_commands(a) assert "cmd1" in c.commands - assert c.call("cmd1 bar") == "ret bar" + assert c.execute("cmd1 bar") == "ret bar" assert "empty" in c.commands - assert c.call("empty") is None + assert c.execute("empty") is None with taddons.context() as tctx: tctx.master.addons.add(a) - assert tctx.master.commands.call("cmd1 bar") == "ret bar" + assert tctx.master.commands.execute("cmd1 bar") == "ret bar" def test_verify_arg_signature(): diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index a6f194a72..4956a1d22 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -1,17 +1,14 @@ import io -from unittest import mock import pytest -from mitmproxy.test import tflow, tutils, taddons +from mitmproxy.test import tflow, taddons import mitmproxy.io from mitmproxy import flowfilter from mitmproxy import options from mitmproxy.io import tnetstring -from mitmproxy.exceptions import FlowReadException, ReplayException +from mitmproxy.exceptions import FlowReadException from mitmproxy import flow from mitmproxy import http -from mitmproxy.net import http as net_http -from mitmproxy import master from . import tservers @@ -122,34 +119,6 @@ class TestFlowMaster: assert s.flows[1].handshake_flow == f.handshake_flow assert len(s.flows[1].messages) == len(f.messages) - def test_replay(self): - opts = options.Options() - fm = master.Master(opts) - f = tflow.tflow(resp=True) - f.request.content = None - with pytest.raises(ReplayException, match="missing"): - fm.replay_request(f) - - f.request = None - with pytest.raises(ReplayException, match="request"): - fm.replay_request(f) - - f.intercepted = True - with pytest.raises(ReplayException, match="intercepted"): - fm.replay_request(f) - - f.live = True - with pytest.raises(ReplayException, match="live"): - fm.replay_request(f) - - req = tutils.treq(headers=net_http.Headers(((b":authority", b"foo"), (b"header", b"qvalue"), (b"content-length", b"7")))) - f = tflow.tflow(req=req) - f.request.http_version = "HTTP/2.0" - with mock.patch('mitmproxy.proxy.protocol.http_replay.RequestReplayThread.run'): - rt = fm.replay_request(f) - assert rt.f.request.http_version == "HTTP/1.1" - assert ":authority" not in rt.f.request.headers - @pytest.mark.asyncio async def test_all(self): opts = options.Options( diff --git a/test/mitmproxy/tools/web/test_app.py b/test/mitmproxy/tools/web/test_app.py index 001888956..668d3c07d 100644 --- a/test/mitmproxy/tools/web/test_app.py +++ b/test/mitmproxy/tools/web/test_app.py @@ -9,7 +9,6 @@ import tornado.testing from tornado import httpclient from tornado import websocket -from mitmproxy import exceptions from mitmproxy import options from mitmproxy.test import tflow from mitmproxy.tools.web import app @@ -186,13 +185,9 @@ class TestApp(tornado.testing.AsyncHTTPTestCase): assert not f._backup def test_flow_replay(self): - with mock.patch("mitmproxy.master.Master.replay_request") as replay_request: + with mock.patch("mitmproxy.command.CommandManager.call") as replay_call: assert self.fetch("/flows/42/replay", method="POST").code == 200 - assert replay_request.called - replay_request.side_effect = exceptions.ReplayException( - "out of replays" - ) - assert self.fetch("/flows/42/replay", method="POST").code == 400 + assert replay_call.called def test_flow_content(self): f = self.view.get_by_id("42")