From 05968a29bb0dbad87ed7e43d53990e6d1b59aba5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 17 Nov 2020 19:07:41 +0100 Subject: [PATCH] [sans-io] implement sans-io based replay --- mitmproxy/addons/__init__.py | 4 + mitmproxy/addons/clientplayback_sansio.py | 205 ++++++++++++++++++++++ mitmproxy/addons/proxyserver.py | 27 +-- mitmproxy/proxy2/context.py | 4 +- mitmproxy/proxy2/layers/http/_http1.py | 27 +-- mitmproxy/proxy2/layers/tls.py | 2 +- mitmproxy/proxy2/server.py | 32 ++-- 7 files changed, 261 insertions(+), 40 deletions(-) create mode 100644 mitmproxy/addons/clientplayback_sansio.py diff --git a/mitmproxy/addons/__init__.py b/mitmproxy/addons/__init__.py index fd330a71a..6f508ac6b 100644 --- a/mitmproxy/addons/__init__.py +++ b/mitmproxy/addons/__init__.py @@ -25,6 +25,10 @@ from mitmproxy.addons import streambodies from mitmproxy.addons import save from mitmproxy.addons import tlsconfig from mitmproxy.addons import upstream_auth +from mitmproxy.utils import compat + +if compat.new_proxy_core: # pragma: no cover + from mitmproxy.addons import clientplayback_sansio as clientplayback # noqa def default_addons(): diff --git a/mitmproxy/addons/clientplayback_sansio.py b/mitmproxy/addons/clientplayback_sansio.py new file mode 100644 index 000000000..a194bc6d7 --- /dev/null +++ b/mitmproxy/addons/clientplayback_sansio.py @@ -0,0 +1,205 @@ +import asyncio +import traceback +import typing + +import mitmproxy.types +from mitmproxy import command +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flow +from mitmproxy import http +from mitmproxy import io +from mitmproxy.addons.proxyserver import AsyncReply +from mitmproxy.net import server_spec +from mitmproxy.options import Options +from mitmproxy.proxy.protocol.http import HTTPMode +from mitmproxy.proxy2 import commands, events, layers, server +from mitmproxy.proxy2.context import Context, Server +from mitmproxy.proxy2.layer import CommandGenerator + + +class MockServer(layers.http.HttpConnection): + """ + A mock HTTP "server" that just pretends it received a full HTTP request, + which is then processed by the proxy core. + """ + flow: http.HTTPFlow + + def __init__(self, flow: http.HTTPFlow, context: Context): + super().__init__(context, context.client) + self.flow = flow + + def _handle_event(self, event: events.Event) -> CommandGenerator[None]: + if isinstance(event, events.Start): + yield layers.http.ReceiveHttp(layers.http.RequestHeaders(1, self.flow.request)) + if self.flow.request.raw_content: + yield layers.http.ReceiveHttp(layers.http.RequestData(1, self.flow.request.raw_content)) + yield layers.http.ReceiveHttp(layers.http.RequestEndOfMessage(1)) + elif isinstance(event, ( + layers.http.ResponseHeaders, + layers.http.ResponseData, + layers.http.ResponseEndOfMessage, + layers.http.ResponseProtocolError, + )): + pass + else: + ctx.log(f"Unexpected event during replay: {events}") + + +class ReplayHandler(server.ConnectionHandler): + def __init__(self, flow: http.HTTPFlow, options: Options) -> None: + client = flow.client_conn.copy() + + context = Context(client, options) + context.server = Server( + (flow.request.host, flow.request.port) + ) + context.server.tls = flow.request.scheme == "https" + if options.mode.startswith("upstream:"): + context.server.via = server_spec.parse_with_mode(options.mode)[1] + + super().__init__(context) + + self.layer = layers.HttpLayer(context, HTTPMode.transparent) + self.layer.connections[client] = MockServer(flow, context.fork()) + self.flow = flow + self.done = asyncio.Event() + + async def replay(self) -> None: + self.server_event(events.Start()) + await self.done.wait() + + def log(self, message: str, level: str = "info") -> None: + ctx.log(f"[replay] {message}", level) + + async def handle_hook(self, hook: commands.Hook) -> None: + data, = hook.as_tuple() + data.reply = AsyncReply(data) + await ctx.master.addons.handle_lifecycle(hook.name, data) + await data.reply.done.wait() + if isinstance(hook, (layers.http.HttpResponseHook, layers.http.HttpErrorHook)): + if self.transports: + # close server connections + for x in self.transports.values(): + x.handler.cancel() + await asyncio.wait([x.handler for x in self.transports.values()]) + # signal completion + self.done.set() + + +class ClientPlayback: + playback_task: typing.Optional[asyncio.Task] + inflight: typing.Optional[http.HTTPFlow] + queue: asyncio.Queue + options: Options + + def __init__(self): + self.queue = asyncio.Queue() + self.inflight = None + self.task = None + + def running(self): + self.playback_task = asyncio.create_task(self.playback()) + self.options = ctx.options + + def done(self): + self.playback_task.cancel() + + async def playback(self): + try: + while True: + self.inflight = await self.queue.get() + try: + h = ReplayHandler(self.inflight, self.options) + await h.replay() + except Exception: + ctx.log(f"Client replay has crashed!\n{traceback.format_exc()}", "error") + self.inflight = None + except asyncio.CancelledError: + return + + def check(self, f: flow.Flow) -> typing.Optional[str]: + if f.live: + return "Can't replay live flow." + if f.intercepted: + return "Can't replay intercepted flow." + if isinstance(f, http.HTTPFlow): + if not f.request: + return "Can't replay flow with missing request." + if f.request.raw_content is None: + return "Can't replay flow with missing content." + else: + return "Can only replay HTTP flows." + + def load(self, loader): + loader.add_option( + "client_replay", typing.Sequence[str], [], + "Replay client requests from a saved file." + ) + + def configure(self, updated): + if "client_replay" in updated and ctx.options.client_replay: + try: + flows = io.read_flows_from_paths(ctx.options.client_replay) + except exceptions.FlowReadException as e: + raise exceptions.OptionsError(str(e)) + self.start_replay(flows) + + @command.command("replay.client.count") + def count(self) -> int: + """ + Approximate number of flows queued for replay. + """ + return self.queue.qsize() + int(bool(self.inflight)) + + @command.command("replay.client.stop") + def stop_replay(self) -> None: + """ + Clear the replay queue. + """ + updated = [] + while True: + try: + f = self.queue.get_nowait() + except asyncio.QueueEmpty: + break + else: + f.revert() + updated.append(f) + + ctx.master.addons.trigger("update", updated) + ctx.log.alert("Client replay queue cleared.") + + @command.command("replay.client") + def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None: + """ + Add flows to the replay queue, skipping flows that can't be replayed. + """ + updated: typing.List[http.HTTPFlow] = [] + for f in flows: + err = self.check(f) + if err: + ctx.log.warn(err) + continue + + http_flow = typing.cast(http.HTTPFlow, f) + + # Prepare the flow for replay + http_flow.backup() + http_flow.is_replay = "request" + http_flow.response = None + http_flow.error = None + self.queue.put_nowait(http_flow) + updated.append(http_flow) + ctx.master.addons.trigger("update", updated) + + @command.command("replay.client.file") + def load_file(self, path: mitmproxy.types.Path) -> None: + """ + Load flows from file, and add them to the replay queue. + """ + try: + flows = io.read_flows_from_paths([path]) + except exceptions.FlowReadException as e: + raise exceptions.CommandError(str(e)) + self.start_replay(flows) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 3dab35f47..de93dd6c2 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -29,17 +29,17 @@ class AsyncReply(controller.Reply): self.obj.error = Error.KILLED_MESSAGE -class ProxyConnectionHandler(server.ConnectionHandler): +class ProxyConnectionHandler(server.StreamConnectionHandler): master: master.Master def __init__(self, master, r, w, options): self.master = master super().__init__(r, w, options) - self.log_prefix = f"{human.format_address(self.client.address)}: " + self.log_prefix = f"{human.format_address(self.client.peername)}: " async def handle_hook(self, hook: commands.Hook) -> None: with self.timeout_watchdog.disarm(): - # TODO: We currently only support single-argument hooks. + # We currently only support single-argument hooks. data, = hook.as_tuple() data.reply = AsyncReply(data) await self.master.addons.handle_lifecycle(hook.name, data) @@ -88,19 +88,22 @@ class Proxyserver: def configure(self, updated): if not self.is_running: return - if any(x in updated for x in ["listen_host", "listen_port"]): - asyncio.ensure_future(self.start_server()) + if any(x in updated for x in ["server", "listen_host", "listen_port"]): + asyncio.ensure_future(self.refresh_server()) - async def start_server(self): + async def refresh_server(self): async with self._lock: if self.server: await self.shutdown_server() - print("Starting server...") - self.server = await asyncio.start_server( - self.handle_connection, - self.options.listen_host, - self.options.listen_port, - ) + self.server = None + if ctx.options.server: + self.server = await asyncio.start_server( + self.handle_connection, + self.options.listen_host, + self.options.listen_port, + ) + addrs = {f"http://{human.format_address(s.getsockname())}" for s in self.server.sockets} + ctx.log.info(f"Proxy server listening at {' and '.join(addrs)}") async def shutdown_server(self): print("Stopping server...") diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index 119150286..c97b6cd0b 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -113,7 +113,7 @@ class Client(Connection): 'address': self.peername, 'alpn_proto_negotiated': self.alpn, 'cipher_name': self.cipher, - 'clientcert': self.certificate_list[0] if self.certificate_list else None, + 'clientcert': self.certificate_list[0].get_state() if self.certificate_list else None, 'id': self.id, 'mitmcert': None, 'sni': self.sni, @@ -181,7 +181,7 @@ class Server(Connection): return { 'address': self.address, 'alpn_proto_negotiated': self.alpn, - 'cert': self.certificate_list[0] if self.certificate_list else None, + 'cert': self.certificate_list[0].get_state() if self.certificate_list else None, 'id': self.id, 'ip_address': self.peername, 'sni': self.sni, diff --git a/mitmproxy/proxy2/layers/http/_http1.py b/mitmproxy/proxy2/layers/http/_http1.py index cbc0b6300..7d63136d7 100644 --- a/mitmproxy/proxy2/layers/http/_http1.py +++ b/mitmproxy/proxy2/layers/http/_http1.py @@ -114,17 +114,18 @@ class Http1Server(Http1Connection): def send(self, event: HttpEvent) -> layer.CommandGenerator[None]: assert event.stream_id == self.stream_id if isinstance(event, ResponseHeaders): - self.response = event.response + self.response = response = event.response - if self.response.is_http2: - # Convert to an HTTP/1 request. - self.response.http_version = b"HTTP/1.1" + if response.is_http2: + response = response.copy() + # Convert to an HTTP/1 response. + response.http_version = b"HTTP/1.1" # not everyone supports empty reason phrases, so we better make up one. - self.response.reason = status_codes.RESPONSES.get(self.response.status_code, "") + response.reason = status_codes.RESPONSES.get(response.status_code, "") # Shall we set a Content-Length header here if there is none? # For now, let's try to modify as little as possible. - raw = http1.assemble_response_head(event.response) + raw = http1.assemble_response_head(response) yield commands.SendData(self.conn, raw) if self.request.first_line_format == "authority": assert self.state == self.wait @@ -237,13 +238,15 @@ class Http1Client(Http1Connection): return if isinstance(event, RequestHeaders): - if event.request.is_http2: + request = event.request + if request.is_http2: # Convert to an HTTP/1 request. - event.request.http_version = b"HTTP/1.1" - if "Host" not in event.request.headers and event.request.authority: - event.request.headers.insert(0, "Host", event.request.authority) - event.request.authority = b"" - raw = http1.assemble_request_head(event.request) + request = request.copy() # (we could probably be a bit more efficient here.) + request.http_version = b"HTTP/1.1" + if "Host" not in request.headers and request.authority: + request.headers.insert(0, "Host", request.authority) + request.authority = b"" + raw = http1.assemble_request_head(request) yield commands.SendData(self.conn, raw) elif isinstance(event, RequestData): if "chunked" in self.request.headers.get("transfer-encoding", "").lower(): diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py index df45e46d4..8990e8a69 100644 --- a/mitmproxy/proxy2/layers/tls.py +++ b/mitmproxy/proxy2/layers/tls.py @@ -194,7 +194,7 @@ class _TLSLayer(tunnel.TunnelLayer): self.conn.cipher = self.tls.get_cipher_name() self.conn.cipher_list = self.tls.get_cipher_list() self.conn.tls_version = self.tls.get_protocol_version_name() - yield commands.Log(f"TLS established: {self.conn}") + yield commands.Log(f"TLS established: {self.conn}", "debug") yield from self.receive_data(b"") return True, None diff --git a/mitmproxy/proxy2/server.py b/mitmproxy/proxy2/server.py index fa187d763..fb6cd4d4c 100644 --- a/mitmproxy/proxy2/server.py +++ b/mitmproxy/proxy2/server.py @@ -13,6 +13,7 @@ import socket import time import traceback import typing +from abc import ABC from contextlib import contextmanager from dataclasses import dataclass @@ -76,23 +77,16 @@ class ConnectionIO: class ConnectionHandler(metaclass=abc.ABCMeta): transports: typing.MutableMapping[Connection, ConnectionIO] timeout_watchdog: TimeoutWatchdog + client: Client - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None: - self.client = Client( - writer.get_extra_info('peername'), - writer.get_extra_info('sockname'), - time.time(), - ) - self.context = Context(self.client, options) - self.transports = { - self.client: ConnectionIO(handler=None, reader=reader, writer=writer) - } + def __init__(self, context: Context) -> None: + self.client = context.client + self.transports = {} # Ask for the first layer right away. # In a reverse proxy scenario, this is necessary as we would otherwise hang # on protocols that start with a server greeting. - self.layer = layer.NextLayer(self.context, ask_on_start=True) - + self.layer = layer.NextLayer(context, ask_on_start=True) self.timeout_watchdog = TimeoutWatchdog(self.on_timeout) async def handle_client(self) -> None: @@ -264,7 +258,19 @@ class ConnectionHandler(metaclass=abc.ABCMeta): self.transports[connection].handler.cancel() -class SimpleConnectionHandler(ConnectionHandler): +class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta): + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None: + client = Client( + writer.get_extra_info('peername'), + writer.get_extra_info('sockname'), + time.time(), + ) + context = Context(client, options) + super().__init__(context) + self.transports[client] = ConnectionIO(handler=None, reader=reader, writer=writer) + + +class SimpleConnectionHandler(StreamConnectionHandler): """Simple handler that does not really process any hooks.""" hook_handlers: typing.Dict[str, typing.Callable]