mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] implement sans-io based replay
This commit is contained in:
parent
895466bc59
commit
05968a29bb
@ -25,6 +25,10 @@ from mitmproxy.addons import streambodies
|
||||
from mitmproxy.addons import save
|
||||
from mitmproxy.addons import tlsconfig
|
||||
from mitmproxy.addons import upstream_auth
|
||||
from mitmproxy.utils import compat
|
||||
|
||||
if compat.new_proxy_core: # pragma: no cover
|
||||
from mitmproxy.addons import clientplayback_sansio as clientplayback # noqa
|
||||
|
||||
|
||||
def default_addons():
|
||||
|
205
mitmproxy/addons/clientplayback_sansio.py
Normal file
205
mitmproxy/addons/clientplayback_sansio.py
Normal file
@ -0,0 +1,205 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
import mitmproxy.types
|
||||
from mitmproxy import command
|
||||
from mitmproxy import ctx
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import flow
|
||||
from mitmproxy import http
|
||||
from mitmproxy import io
|
||||
from mitmproxy.addons.proxyserver import AsyncReply
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.options import Options
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events, layers, server
|
||||
from mitmproxy.proxy2.context import Context, Server
|
||||
from mitmproxy.proxy2.layer import CommandGenerator
|
||||
|
||||
|
||||
class MockServer(layers.http.HttpConnection):
|
||||
"""
|
||||
A mock HTTP "server" that just pretends it received a full HTTP request,
|
||||
which is then processed by the proxy core.
|
||||
"""
|
||||
flow: http.HTTPFlow
|
||||
|
||||
def __init__(self, flow: http.HTTPFlow, context: Context):
|
||||
super().__init__(context, context.client)
|
||||
self.flow = flow
|
||||
|
||||
def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
yield layers.http.ReceiveHttp(layers.http.RequestHeaders(1, self.flow.request))
|
||||
if self.flow.request.raw_content:
|
||||
yield layers.http.ReceiveHttp(layers.http.RequestData(1, self.flow.request.raw_content))
|
||||
yield layers.http.ReceiveHttp(layers.http.RequestEndOfMessage(1))
|
||||
elif isinstance(event, (
|
||||
layers.http.ResponseHeaders,
|
||||
layers.http.ResponseData,
|
||||
layers.http.ResponseEndOfMessage,
|
||||
layers.http.ResponseProtocolError,
|
||||
)):
|
||||
pass
|
||||
else:
|
||||
ctx.log(f"Unexpected event during replay: {events}")
|
||||
|
||||
|
||||
class ReplayHandler(server.ConnectionHandler):
|
||||
def __init__(self, flow: http.HTTPFlow, options: Options) -> None:
|
||||
client = flow.client_conn.copy()
|
||||
|
||||
context = Context(client, options)
|
||||
context.server = Server(
|
||||
(flow.request.host, flow.request.port)
|
||||
)
|
||||
context.server.tls = flow.request.scheme == "https"
|
||||
if options.mode.startswith("upstream:"):
|
||||
context.server.via = server_spec.parse_with_mode(options.mode)[1]
|
||||
|
||||
super().__init__(context)
|
||||
|
||||
self.layer = layers.HttpLayer(context, HTTPMode.transparent)
|
||||
self.layer.connections[client] = MockServer(flow, context.fork())
|
||||
self.flow = flow
|
||||
self.done = asyncio.Event()
|
||||
|
||||
async def replay(self) -> None:
|
||||
self.server_event(events.Start())
|
||||
await self.done.wait()
|
||||
|
||||
def log(self, message: str, level: str = "info") -> None:
|
||||
ctx.log(f"[replay] {message}", level)
|
||||
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
data, = hook.as_tuple()
|
||||
data.reply = AsyncReply(data)
|
||||
await ctx.master.addons.handle_lifecycle(hook.name, data)
|
||||
await data.reply.done.wait()
|
||||
if isinstance(hook, (layers.http.HttpResponseHook, layers.http.HttpErrorHook)):
|
||||
if self.transports:
|
||||
# close server connections
|
||||
for x in self.transports.values():
|
||||
x.handler.cancel()
|
||||
await asyncio.wait([x.handler for x in self.transports.values()])
|
||||
# signal completion
|
||||
self.done.set()
|
||||
|
||||
|
||||
class ClientPlayback:
|
||||
playback_task: typing.Optional[asyncio.Task]
|
||||
inflight: typing.Optional[http.HTTPFlow]
|
||||
queue: asyncio.Queue
|
||||
options: Options
|
||||
|
||||
def __init__(self):
|
||||
self.queue = asyncio.Queue()
|
||||
self.inflight = None
|
||||
self.task = None
|
||||
|
||||
def running(self):
|
||||
self.playback_task = asyncio.create_task(self.playback())
|
||||
self.options = ctx.options
|
||||
|
||||
def done(self):
|
||||
self.playback_task.cancel()
|
||||
|
||||
async def playback(self):
|
||||
try:
|
||||
while True:
|
||||
self.inflight = await self.queue.get()
|
||||
try:
|
||||
h = ReplayHandler(self.inflight, self.options)
|
||||
await h.replay()
|
||||
except Exception:
|
||||
ctx.log(f"Client replay has crashed!\n{traceback.format_exc()}", "error")
|
||||
self.inflight = None
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
def check(self, f: flow.Flow) -> typing.Optional[str]:
|
||||
if f.live:
|
||||
return "Can't replay live flow."
|
||||
if f.intercepted:
|
||||
return "Can't replay intercepted flow."
|
||||
if isinstance(f, http.HTTPFlow):
|
||||
if not f.request:
|
||||
return "Can't replay flow with missing request."
|
||||
if f.request.raw_content is None:
|
||||
return "Can't replay flow with missing content."
|
||||
else:
|
||||
return "Can only replay HTTP flows."
|
||||
|
||||
def load(self, loader):
|
||||
loader.add_option(
|
||||
"client_replay", typing.Sequence[str], [],
|
||||
"Replay client requests from a saved file."
|
||||
)
|
||||
|
||||
def configure(self, updated):
|
||||
if "client_replay" in updated and ctx.options.client_replay:
|
||||
try:
|
||||
flows = io.read_flows_from_paths(ctx.options.client_replay)
|
||||
except exceptions.FlowReadException as e:
|
||||
raise exceptions.OptionsError(str(e))
|
||||
self.start_replay(flows)
|
||||
|
||||
@command.command("replay.client.count")
|
||||
def count(self) -> int:
|
||||
"""
|
||||
Approximate number of flows queued for replay.
|
||||
"""
|
||||
return self.queue.qsize() + int(bool(self.inflight))
|
||||
|
||||
@command.command("replay.client.stop")
|
||||
def stop_replay(self) -> None:
|
||||
"""
|
||||
Clear the replay queue.
|
||||
"""
|
||||
updated = []
|
||||
while True:
|
||||
try:
|
||||
f = self.queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
else:
|
||||
f.revert()
|
||||
updated.append(f)
|
||||
|
||||
ctx.master.addons.trigger("update", updated)
|
||||
ctx.log.alert("Client replay queue cleared.")
|
||||
|
||||
@command.command("replay.client")
|
||||
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
|
||||
"""
|
||||
Add flows to the replay queue, skipping flows that can't be replayed.
|
||||
"""
|
||||
updated: typing.List[http.HTTPFlow] = []
|
||||
for f in flows:
|
||||
err = self.check(f)
|
||||
if err:
|
||||
ctx.log.warn(err)
|
||||
continue
|
||||
|
||||
http_flow = typing.cast(http.HTTPFlow, f)
|
||||
|
||||
# Prepare the flow for replay
|
||||
http_flow.backup()
|
||||
http_flow.is_replay = "request"
|
||||
http_flow.response = None
|
||||
http_flow.error = None
|
||||
self.queue.put_nowait(http_flow)
|
||||
updated.append(http_flow)
|
||||
ctx.master.addons.trigger("update", updated)
|
||||
|
||||
@command.command("replay.client.file")
|
||||
def load_file(self, path: mitmproxy.types.Path) -> None:
|
||||
"""
|
||||
Load flows from file, and add them to the replay queue.
|
||||
"""
|
||||
try:
|
||||
flows = io.read_flows_from_paths([path])
|
||||
except exceptions.FlowReadException as e:
|
||||
raise exceptions.CommandError(str(e))
|
||||
self.start_replay(flows)
|
@ -29,17 +29,17 @@ class AsyncReply(controller.Reply):
|
||||
self.obj.error = Error.KILLED_MESSAGE
|
||||
|
||||
|
||||
class ProxyConnectionHandler(server.ConnectionHandler):
|
||||
class ProxyConnectionHandler(server.StreamConnectionHandler):
|
||||
master: master.Master
|
||||
|
||||
def __init__(self, master, r, w, options):
|
||||
self.master = master
|
||||
super().__init__(r, w, options)
|
||||
self.log_prefix = f"{human.format_address(self.client.address)}: "
|
||||
self.log_prefix = f"{human.format_address(self.client.peername)}: "
|
||||
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
with self.timeout_watchdog.disarm():
|
||||
# TODO: We currently only support single-argument hooks.
|
||||
# We currently only support single-argument hooks.
|
||||
data, = hook.as_tuple()
|
||||
data.reply = AsyncReply(data)
|
||||
await self.master.addons.handle_lifecycle(hook.name, data)
|
||||
@ -88,19 +88,22 @@ class Proxyserver:
|
||||
def configure(self, updated):
|
||||
if not self.is_running:
|
||||
return
|
||||
if any(x in updated for x in ["listen_host", "listen_port"]):
|
||||
asyncio.ensure_future(self.start_server())
|
||||
if any(x in updated for x in ["server", "listen_host", "listen_port"]):
|
||||
asyncio.ensure_future(self.refresh_server())
|
||||
|
||||
async def start_server(self):
|
||||
async def refresh_server(self):
|
||||
async with self._lock:
|
||||
if self.server:
|
||||
await self.shutdown_server()
|
||||
print("Starting server...")
|
||||
self.server = await asyncio.start_server(
|
||||
self.handle_connection,
|
||||
self.options.listen_host,
|
||||
self.options.listen_port,
|
||||
)
|
||||
self.server = None
|
||||
if ctx.options.server:
|
||||
self.server = await asyncio.start_server(
|
||||
self.handle_connection,
|
||||
self.options.listen_host,
|
||||
self.options.listen_port,
|
||||
)
|
||||
addrs = {f"http://{human.format_address(s.getsockname())}" for s in self.server.sockets}
|
||||
ctx.log.info(f"Proxy server listening at {' and '.join(addrs)}")
|
||||
|
||||
async def shutdown_server(self):
|
||||
print("Stopping server...")
|
||||
|
@ -113,7 +113,7 @@ class Client(Connection):
|
||||
'address': self.peername,
|
||||
'alpn_proto_negotiated': self.alpn,
|
||||
'cipher_name': self.cipher,
|
||||
'clientcert': self.certificate_list[0] if self.certificate_list else None,
|
||||
'clientcert': self.certificate_list[0].get_state() if self.certificate_list else None,
|
||||
'id': self.id,
|
||||
'mitmcert': None,
|
||||
'sni': self.sni,
|
||||
@ -181,7 +181,7 @@ class Server(Connection):
|
||||
return {
|
||||
'address': self.address,
|
||||
'alpn_proto_negotiated': self.alpn,
|
||||
'cert': self.certificate_list[0] if self.certificate_list else None,
|
||||
'cert': self.certificate_list[0].get_state() if self.certificate_list else None,
|
||||
'id': self.id,
|
||||
'ip_address': self.peername,
|
||||
'sni': self.sni,
|
||||
|
@ -114,17 +114,18 @@ class Http1Server(Http1Connection):
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
assert event.stream_id == self.stream_id
|
||||
if isinstance(event, ResponseHeaders):
|
||||
self.response = event.response
|
||||
self.response = response = event.response
|
||||
|
||||
if self.response.is_http2:
|
||||
# Convert to an HTTP/1 request.
|
||||
self.response.http_version = b"HTTP/1.1"
|
||||
if response.is_http2:
|
||||
response = response.copy()
|
||||
# Convert to an HTTP/1 response.
|
||||
response.http_version = b"HTTP/1.1"
|
||||
# not everyone supports empty reason phrases, so we better make up one.
|
||||
self.response.reason = status_codes.RESPONSES.get(self.response.status_code, "")
|
||||
response.reason = status_codes.RESPONSES.get(response.status_code, "")
|
||||
# Shall we set a Content-Length header here if there is none?
|
||||
# For now, let's try to modify as little as possible.
|
||||
|
||||
raw = http1.assemble_response_head(event.response)
|
||||
raw = http1.assemble_response_head(response)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
if self.request.first_line_format == "authority":
|
||||
assert self.state == self.wait
|
||||
@ -237,13 +238,15 @@ class Http1Client(Http1Connection):
|
||||
return
|
||||
|
||||
if isinstance(event, RequestHeaders):
|
||||
if event.request.is_http2:
|
||||
request = event.request
|
||||
if request.is_http2:
|
||||
# Convert to an HTTP/1 request.
|
||||
event.request.http_version = b"HTTP/1.1"
|
||||
if "Host" not in event.request.headers and event.request.authority:
|
||||
event.request.headers.insert(0, "Host", event.request.authority)
|
||||
event.request.authority = b""
|
||||
raw = http1.assemble_request_head(event.request)
|
||||
request = request.copy() # (we could probably be a bit more efficient here.)
|
||||
request.http_version = b"HTTP/1.1"
|
||||
if "Host" not in request.headers and request.authority:
|
||||
request.headers.insert(0, "Host", request.authority)
|
||||
request.authority = b""
|
||||
raw = http1.assemble_request_head(request)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, RequestData):
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
|
@ -194,7 +194,7 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
self.conn.cipher = self.tls.get_cipher_name()
|
||||
self.conn.cipher_list = self.tls.get_cipher_list()
|
||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||
yield commands.Log(f"TLS established: {self.conn}")
|
||||
yield commands.Log(f"TLS established: {self.conn}", "debug")
|
||||
yield from self.receive_data(b"")
|
||||
return True, None
|
||||
|
||||
|
@ -13,6 +13,7 @@ import socket
|
||||
import time
|
||||
import traceback
|
||||
import typing
|
||||
from abc import ABC
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -76,23 +77,16 @@ class ConnectionIO:
|
||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
transports: typing.MutableMapping[Connection, ConnectionIO]
|
||||
timeout_watchdog: TimeoutWatchdog
|
||||
client: Client
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||
self.client = Client(
|
||||
writer.get_extra_info('peername'),
|
||||
writer.get_extra_info('sockname'),
|
||||
time.time(),
|
||||
)
|
||||
self.context = Context(self.client, options)
|
||||
self.transports = {
|
||||
self.client: ConnectionIO(handler=None, reader=reader, writer=writer)
|
||||
}
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.client = context.client
|
||||
self.transports = {}
|
||||
|
||||
# Ask for the first layer right away.
|
||||
# In a reverse proxy scenario, this is necessary as we would otherwise hang
|
||||
# on protocols that start with a server greeting.
|
||||
self.layer = layer.NextLayer(self.context, ask_on_start=True)
|
||||
|
||||
self.layer = layer.NextLayer(context, ask_on_start=True)
|
||||
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
||||
|
||||
async def handle_client(self) -> None:
|
||||
@ -264,7 +258,19 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
self.transports[connection].handler.cancel()
|
||||
|
||||
|
||||
class SimpleConnectionHandler(ConnectionHandler):
|
||||
class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||
client = Client(
|
||||
writer.get_extra_info('peername'),
|
||||
writer.get_extra_info('sockname'),
|
||||
time.time(),
|
||||
)
|
||||
context = Context(client, options)
|
||||
super().__init__(context)
|
||||
self.transports[client] = ConnectionIO(handler=None, reader=reader, writer=writer)
|
||||
|
||||
|
||||
class SimpleConnectionHandler(StreamConnectionHandler):
|
||||
"""Simple handler that does not really process any hooks."""
|
||||
|
||||
hook_handlers: typing.Dict[str, typing.Callable]
|
||||
|
Loading…
Reference in New Issue
Block a user