[sans-io] implement sans-io based replay

Maximilian Hils 2020-11-17 19:07:41 +01:00
parent 895466bc59
commit 05968a29bb
7 changed files with 261 additions and 40 deletions

mitmproxy/addons/__init__.py

@@ -25,6 +25,10 @@ from mitmproxy.addons import streambodies
 from mitmproxy.addons import save
 from mitmproxy.addons import tlsconfig
 from mitmproxy.addons import upstream_auth
+from mitmproxy.utils import compat
+
+if compat.new_proxy_core:  # pragma: no cover
+    from mitmproxy.addons import clientplayback_sansio as clientplayback  # noqa


 def default_addons():

mitmproxy/addons/clientplayback_sansio.py (new file)

@@ -0,0 +1,205 @@
import asyncio
import traceback
import typing

import mitmproxy.types
from mitmproxy import command
from mitmproxy import ctx
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy import http
from mitmproxy import io
from mitmproxy.addons.proxyserver import AsyncReply
from mitmproxy.net import server_spec
from mitmproxy.options import Options
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layers, server
from mitmproxy.proxy2.context import Context, Server
from mitmproxy.proxy2.layer import CommandGenerator


class MockServer(layers.http.HttpConnection):
    """
    A mock HTTP "server" that just pretends it received a full HTTP request,
    which is then processed by the proxy core.
    """
    flow: http.HTTPFlow

    def __init__(self, flow: http.HTTPFlow, context: Context):
        super().__init__(context, context.client)
        self.flow = flow

    def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
        if isinstance(event, events.Start):
            yield layers.http.ReceiveHttp(layers.http.RequestHeaders(1, self.flow.request))
            if self.flow.request.raw_content:
                yield layers.http.ReceiveHttp(layers.http.RequestData(1, self.flow.request.raw_content))
            yield layers.http.ReceiveHttp(layers.http.RequestEndOfMessage(1))
        elif isinstance(event, (
                layers.http.ResponseHeaders,
                layers.http.ResponseData,
                layers.http.ResponseEndOfMessage,
                layers.http.ResponseProtocolError,
        )):
            pass
        else:
            ctx.log(f"Unexpected event during replay: {event}")


class ReplayHandler(server.ConnectionHandler):
    def __init__(self, flow: http.HTTPFlow, options: Options) -> None:
        client = flow.client_conn.copy()

        context = Context(client, options)
        context.server = Server(
            (flow.request.host, flow.request.port)
        )
        context.server.tls = flow.request.scheme == "https"
        if options.mode.startswith("upstream:"):
            context.server.via = server_spec.parse_with_mode(options.mode)[1]

        super().__init__(context)

        self.layer = layers.HttpLayer(context, HTTPMode.transparent)
        # The mock server takes the place of the client-side HTTP connection, so the
        # HTTP layer receives the recorded request instead of reading from a socket.
        self.layer.connections[client] = MockServer(flow, context.fork())
        self.flow = flow
        self.done = asyncio.Event()

    async def replay(self) -> None:
        self.server_event(events.Start())
        await self.done.wait()

    def log(self, message: str, level: str = "info") -> None:
        ctx.log(f"[replay] {message}", level)

    async def handle_hook(self, hook: commands.Hook) -> None:
        data, = hook.as_tuple()
        data.reply = AsyncReply(data)
        await ctx.master.addons.handle_lifecycle(hook.name, data)
        await data.reply.done.wait()
        if isinstance(hook, (layers.http.HttpResponseHook, layers.http.HttpErrorHook)):
            if self.transports:
                # close server connections
                for x in self.transports.values():
                    x.handler.cancel()
                await asyncio.wait([x.handler for x in self.transports.values()])
            # signal completion
            self.done.set()


class ClientPlayback:
    playback_task: typing.Optional[asyncio.Task]
    inflight: typing.Optional[http.HTTPFlow]
    queue: asyncio.Queue
    options: Options

    def __init__(self):
        self.queue = asyncio.Queue()
        self.inflight = None
        self.playback_task = None

    def running(self):
        self.playback_task = asyncio.create_task(self.playback())
        self.options = ctx.options

    def done(self):
        self.playback_task.cancel()

    async def playback(self):
        try:
            while True:
                self.inflight = await self.queue.get()
                try:
                    h = ReplayHandler(self.inflight, self.options)
                    await h.replay()
                except Exception:
                    ctx.log(f"Client replay has crashed!\n{traceback.format_exc()}", "error")
                self.inflight = None
        except asyncio.CancelledError:
            return

    def check(self, f: flow.Flow) -> typing.Optional[str]:
        if f.live:
            return "Can't replay live flow."
        if f.intercepted:
            return "Can't replay intercepted flow."
        if isinstance(f, http.HTTPFlow):
            if not f.request:
                return "Can't replay flow with missing request."
            if f.request.raw_content is None:
                return "Can't replay flow with missing content."
        else:
            return "Can only replay HTTP flows."

    def load(self, loader):
        loader.add_option(
            "client_replay", typing.Sequence[str], [],
            "Replay client requests from a saved file."
        )

    def configure(self, updated):
        if "client_replay" in updated and ctx.options.client_replay:
            try:
                flows = io.read_flows_from_paths(ctx.options.client_replay)
            except exceptions.FlowReadException as e:
                raise exceptions.OptionsError(str(e))
            self.start_replay(flows)

    @command.command("replay.client.count")
    def count(self) -> int:
        """
        Approximate number of flows queued for replay.
        """
        return self.queue.qsize() + int(bool(self.inflight))

    @command.command("replay.client.stop")
    def stop_replay(self) -> None:
        """
        Clear the replay queue.
        """
        updated = []
        while True:
            try:
                f = self.queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            else:
                f.revert()
                updated.append(f)
        ctx.master.addons.trigger("update", updated)
        ctx.log.alert("Client replay queue cleared.")

    @command.command("replay.client")
    def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
        """
        Add flows to the replay queue, skipping flows that can't be replayed.
        """
        updated: typing.List[http.HTTPFlow] = []
        for f in flows:
            err = self.check(f)
            if err:
                ctx.log.warn(err)
                continue

            http_flow = typing.cast(http.HTTPFlow, f)
            # Prepare the flow for replay.
            http_flow.backup()
            http_flow.is_replay = "request"
            http_flow.response = None
            http_flow.error = None
            self.queue.put_nowait(http_flow)
            updated.append(http_flow)
        ctx.master.addons.trigger("update", updated)

    @command.command("replay.client.file")
    def load_file(self, path: mitmproxy.types.Path) -> None:
        """
        Load flows from file, and add them to the replay queue.
        """
        try:
            flows = io.read_flows_from_paths([path])
        except exceptions.FlowReadException as e:
            raise exceptions.CommandError(str(e))
        self.start_replay(flows)
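Usage sketch for the addon above. The commands it registers (replay.client, replay.client.file, replay.client.stop, replay.client.count) and the client_replay option are taken from the code; the lookup name "clientplayback" and the capture file "flows.mitm" are illustrative assumptions, not part of this commit.

from mitmproxy import ctx, exceptions, io


class QueueReplayOnStartup:
    """Hypothetical companion addon: queue a saved capture for client replay at startup."""

    def running(self):
        try:
            flows = io.read_flows_from_paths(["flows.mitm"])  # assumed capture file
        except exceptions.FlowReadException as e:
            ctx.log.error(str(e))
            return
        # Assumption: the playback addon is registered under its lowercased class name.
        playback = ctx.master.addons.get("clientplayback")
        playback.start_replay(flows)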

mitmproxy/addons/proxyserver.py

@@ -29,17 +29,17 @@ class AsyncReply(controller.Reply):
         self.obj.error = Error.KILLED_MESSAGE


-class ProxyConnectionHandler(server.ConnectionHandler):
+class ProxyConnectionHandler(server.StreamConnectionHandler):
     master: master.Master

     def __init__(self, master, r, w, options):
         self.master = master
         super().__init__(r, w, options)
-        self.log_prefix = f"{human.format_address(self.client.address)}: "
+        self.log_prefix = f"{human.format_address(self.client.peername)}: "

     async def handle_hook(self, hook: commands.Hook) -> None:
         with self.timeout_watchdog.disarm():
-            # TODO: We currently only support single-argument hooks.
+            # We currently only support single-argument hooks.
             data, = hook.as_tuple()
             data.reply = AsyncReply(data)
             await self.master.addons.handle_lifecycle(hook.name, data)
@@ -88,19 +88,22 @@ class Proxyserver:
     def configure(self, updated):
         if not self.is_running:
             return
-        if any(x in updated for x in ["listen_host", "listen_port"]):
-            asyncio.ensure_future(self.start_server())
+        if any(x in updated for x in ["server", "listen_host", "listen_port"]):
+            asyncio.ensure_future(self.refresh_server())

-    async def start_server(self):
+    async def refresh_server(self):
         async with self._lock:
             if self.server:
                 await self.shutdown_server()
-            print("Starting server...")
-            self.server = await asyncio.start_server(
-                self.handle_connection,
-                self.options.listen_host,
-                self.options.listen_port,
-            )
+                self.server = None
+            if ctx.options.server:
+                self.server = await asyncio.start_server(
+                    self.handle_connection,
+                    self.options.listen_host,
+                    self.options.listen_port,
+                )
+                addrs = {f"http://{human.format_address(s.getsockname())}" for s in self.server.sockets}
+                ctx.log.info(f"Proxy server listening at {' and '.join(addrs)}")

     async def shutdown_server(self):
         print("Stopping server...")

mitmproxy/proxy2/context.py

@@ -113,7 +113,7 @@ class Client(Connection):
             'address': self.peername,
             'alpn_proto_negotiated': self.alpn,
             'cipher_name': self.cipher,
-            'clientcert': self.certificate_list[0] if self.certificate_list else None,
+            'clientcert': self.certificate_list[0].get_state() if self.certificate_list else None,
             'id': self.id,
             'mitmcert': None,
             'sni': self.sni,
@@ -181,7 +181,7 @@ class Server(Connection):
         return {
             'address': self.address,
             'alpn_proto_negotiated': self.alpn,
-            'cert': self.certificate_list[0] if self.certificate_list else None,
+            'cert': self.certificate_list[0].get_state() if self.certificate_list else None,
             'id': self.id,
             'ip_address': self.peername,
             'sni': self.sni,

mitmproxy/proxy2/layers/http/_http1.py

@@ -114,17 +114,18 @@ class Http1Server(Http1Connection):
     def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
         assert event.stream_id == self.stream_id
         if isinstance(event, ResponseHeaders):
-            self.response = event.response
-            if self.response.is_http2:
-                # Convert to an HTTP/1 request.
-                self.response.http_version = b"HTTP/1.1"
+            self.response = response = event.response
+            if response.is_http2:
+                response = response.copy()
+                # Convert to an HTTP/1 response.
+                response.http_version = b"HTTP/1.1"
                 # not everyone supports empty reason phrases, so we better make up one.
-                self.response.reason = status_codes.RESPONSES.get(self.response.status_code, "")
+                response.reason = status_codes.RESPONSES.get(response.status_code, "")
             # Shall we set a Content-Length header here if there is none?
             # For now, let's try to modify as little as possible.
-            raw = http1.assemble_response_head(event.response)
+            raw = http1.assemble_response_head(response)
             yield commands.SendData(self.conn, raw)
             if self.request.first_line_format == "authority":
                 assert self.state == self.wait
@@ -237,13 +238,15 @@ class Http1Client(Http1Connection):
             return
         if isinstance(event, RequestHeaders):
-            if event.request.is_http2:
+            request = event.request
+            if request.is_http2:
                 # Convert to an HTTP/1 request.
-                event.request.http_version = b"HTTP/1.1"
-                if "Host" not in event.request.headers and event.request.authority:
-                    event.request.headers.insert(0, "Host", event.request.authority)
-                event.request.authority = b""
-            raw = http1.assemble_request_head(event.request)
+                request = request.copy()  # (we could probably be a bit more efficient here.)
+                request.http_version = b"HTTP/1.1"
+                if "Host" not in request.headers and request.authority:
+                    request.headers.insert(0, "Host", request.authority)
+                request.authority = b""
+            raw = http1.assemble_request_head(request)
             yield commands.SendData(self.conn, raw)
         elif isinstance(event, RequestData):
             if "chunked" in self.request.headers.get("transfer-encoding", "").lower():

mitmproxy/proxy2/layers/tls.py

@@ -194,7 +194,7 @@ class _TLSLayer(tunnel.TunnelLayer):
             self.conn.cipher = self.tls.get_cipher_name()
             self.conn.cipher_list = self.tls.get_cipher_list()
             self.conn.tls_version = self.tls.get_protocol_version_name()
-            yield commands.Log(f"TLS established: {self.conn}")
+            yield commands.Log(f"TLS established: {self.conn}", "debug")
             yield from self.receive_data(b"")
             return True, None

mitmproxy/proxy2/server.py

@@ -13,6 +13,7 @@ import socket
 import time
 import traceback
 import typing
+from abc import ABC
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -76,23 +77,16 @@ class ConnectionIO:
 class ConnectionHandler(metaclass=abc.ABCMeta):
     transports: typing.MutableMapping[Connection, ConnectionIO]
     timeout_watchdog: TimeoutWatchdog
+    client: Client

-    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
-        self.client = Client(
-            writer.get_extra_info('peername'),
-            writer.get_extra_info('sockname'),
-            time.time(),
-        )
-        self.context = Context(self.client, options)
-        self.transports = {
-            self.client: ConnectionIO(handler=None, reader=reader, writer=writer)
-        }
+    def __init__(self, context: Context) -> None:
+        self.client = context.client
+        self.transports = {}

         # Ask for the first layer right away.
         # In a reverse proxy scenario, this is necessary as we would otherwise hang
         # on protocols that start with a server greeting.
-        self.layer = layer.NextLayer(self.context, ask_on_start=True)
+        self.layer = layer.NextLayer(context, ask_on_start=True)
         self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)

     async def handle_client(self) -> None:
@@ -264,7 +258,19 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
             self.transports[connection].handler.cancel()


-class SimpleConnectionHandler(ConnectionHandler):
+class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
+    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
+        client = Client(
+            writer.get_extra_info('peername'),
+            writer.get_extra_info('sockname'),
+            time.time(),
+        )
+        context = Context(client, options)
+        super().__init__(context)
+        self.transports[client] = ConnectionIO(handler=None, reader=reader, writer=writer)
+
+
+class SimpleConnectionHandler(StreamConnectionHandler):
     """Simple handler that does not really process any hooks."""

     hook_handlers: typing.Dict[str, typing.Callable]
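The split above is what enables the sans-io replay: anything that can build a Context may subclass ConnectionHandler directly, as ReplayHandler does in clientplayback_sansio.py, while socket-backed code goes through StreamConnectionHandler. A minimal sketch of a context-only handler, using only names from this commit (the subclass itself is hypothetical):

from mitmproxy.proxy2 import commands, events, server
from mitmproxy.proxy2.context import Context


class InMemoryHandler(server.ConnectionHandler):
    """Illustrative handler that drives the sans-io core without a socket."""

    def __init__(self, context: Context) -> None:
        super().__init__(context)

    def log(self, message: str, level: str = "info") -> None:
        print(f"[in-memory] {level}: {message}")

    async def handle_hook(self, hook: commands.Hook) -> None:
        pass  # sketch: ignore hooks

    def start(self) -> None:
        # Kick off processing, as ReplayHandler.replay() does.
        self.server_event(events.Start())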