mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
[sans-io] implement sans-io based replay
This commit is contained in:
parent
895466bc59
commit
05968a29bb
@ -25,6 +25,10 @@ from mitmproxy.addons import streambodies
|
|||||||
from mitmproxy.addons import save
|
from mitmproxy.addons import save
|
||||||
from mitmproxy.addons import tlsconfig
|
from mitmproxy.addons import tlsconfig
|
||||||
from mitmproxy.addons import upstream_auth
|
from mitmproxy.addons import upstream_auth
|
||||||
|
from mitmproxy.utils import compat
|
||||||
|
|
||||||
|
if compat.new_proxy_core: # pragma: no cover
|
||||||
|
from mitmproxy.addons import clientplayback_sansio as clientplayback # noqa
|
||||||
|
|
||||||
|
|
||||||
def default_addons():
|
def default_addons():
|
||||||
|
205
mitmproxy/addons/clientplayback_sansio.py
Normal file
205
mitmproxy/addons/clientplayback_sansio.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import mitmproxy.types
|
||||||
|
from mitmproxy import command
|
||||||
|
from mitmproxy import ctx
|
||||||
|
from mitmproxy import exceptions
|
||||||
|
from mitmproxy import flow
|
||||||
|
from mitmproxy import http
|
||||||
|
from mitmproxy import io
|
||||||
|
from mitmproxy.addons.proxyserver import AsyncReply
|
||||||
|
from mitmproxy.net import server_spec
|
||||||
|
from mitmproxy.options import Options
|
||||||
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
|
from mitmproxy.proxy2 import commands, events, layers, server
|
||||||
|
from mitmproxy.proxy2.context import Context, Server
|
||||||
|
from mitmproxy.proxy2.layer import CommandGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class MockServer(layers.http.HttpConnection):
|
||||||
|
"""
|
||||||
|
A mock HTTP "server" that just pretends it received a full HTTP request,
|
||||||
|
which is then processed by the proxy core.
|
||||||
|
"""
|
||||||
|
flow: http.HTTPFlow
|
||||||
|
|
||||||
|
def __init__(self, flow: http.HTTPFlow, context: Context):
|
||||||
|
super().__init__(context, context.client)
|
||||||
|
self.flow = flow
|
||||||
|
|
||||||
|
def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
|
||||||
|
if isinstance(event, events.Start):
|
||||||
|
yield layers.http.ReceiveHttp(layers.http.RequestHeaders(1, self.flow.request))
|
||||||
|
if self.flow.request.raw_content:
|
||||||
|
yield layers.http.ReceiveHttp(layers.http.RequestData(1, self.flow.request.raw_content))
|
||||||
|
yield layers.http.ReceiveHttp(layers.http.RequestEndOfMessage(1))
|
||||||
|
elif isinstance(event, (
|
||||||
|
layers.http.ResponseHeaders,
|
||||||
|
layers.http.ResponseData,
|
||||||
|
layers.http.ResponseEndOfMessage,
|
||||||
|
layers.http.ResponseProtocolError,
|
||||||
|
)):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
ctx.log(f"Unexpected event during replay: {events}")
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayHandler(server.ConnectionHandler):
|
||||||
|
def __init__(self, flow: http.HTTPFlow, options: Options) -> None:
|
||||||
|
client = flow.client_conn.copy()
|
||||||
|
|
||||||
|
context = Context(client, options)
|
||||||
|
context.server = Server(
|
||||||
|
(flow.request.host, flow.request.port)
|
||||||
|
)
|
||||||
|
context.server.tls = flow.request.scheme == "https"
|
||||||
|
if options.mode.startswith("upstream:"):
|
||||||
|
context.server.via = server_spec.parse_with_mode(options.mode)[1]
|
||||||
|
|
||||||
|
super().__init__(context)
|
||||||
|
|
||||||
|
self.layer = layers.HttpLayer(context, HTTPMode.transparent)
|
||||||
|
self.layer.connections[client] = MockServer(flow, context.fork())
|
||||||
|
self.flow = flow
|
||||||
|
self.done = asyncio.Event()
|
||||||
|
|
||||||
|
async def replay(self) -> None:
|
||||||
|
self.server_event(events.Start())
|
||||||
|
await self.done.wait()
|
||||||
|
|
||||||
|
def log(self, message: str, level: str = "info") -> None:
|
||||||
|
ctx.log(f"[replay] {message}", level)
|
||||||
|
|
||||||
|
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||||
|
data, = hook.as_tuple()
|
||||||
|
data.reply = AsyncReply(data)
|
||||||
|
await ctx.master.addons.handle_lifecycle(hook.name, data)
|
||||||
|
await data.reply.done.wait()
|
||||||
|
if isinstance(hook, (layers.http.HttpResponseHook, layers.http.HttpErrorHook)):
|
||||||
|
if self.transports:
|
||||||
|
# close server connections
|
||||||
|
for x in self.transports.values():
|
||||||
|
x.handler.cancel()
|
||||||
|
await asyncio.wait([x.handler for x in self.transports.values()])
|
||||||
|
# signal completion
|
||||||
|
self.done.set()
|
||||||
|
|
||||||
|
|
||||||
|
class ClientPlayback:
|
||||||
|
playback_task: typing.Optional[asyncio.Task]
|
||||||
|
inflight: typing.Optional[http.HTTPFlow]
|
||||||
|
queue: asyncio.Queue
|
||||||
|
options: Options
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.queue = asyncio.Queue()
|
||||||
|
self.inflight = None
|
||||||
|
self.task = None
|
||||||
|
|
||||||
|
def running(self):
|
||||||
|
self.playback_task = asyncio.create_task(self.playback())
|
||||||
|
self.options = ctx.options
|
||||||
|
|
||||||
|
def done(self):
|
||||||
|
self.playback_task.cancel()
|
||||||
|
|
||||||
|
async def playback(self):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self.inflight = await self.queue.get()
|
||||||
|
try:
|
||||||
|
h = ReplayHandler(self.inflight, self.options)
|
||||||
|
await h.replay()
|
||||||
|
except Exception:
|
||||||
|
ctx.log(f"Client replay has crashed!\n{traceback.format_exc()}", "error")
|
||||||
|
self.inflight = None
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
|
||||||
|
def check(self, f: flow.Flow) -> typing.Optional[str]:
|
||||||
|
if f.live:
|
||||||
|
return "Can't replay live flow."
|
||||||
|
if f.intercepted:
|
||||||
|
return "Can't replay intercepted flow."
|
||||||
|
if isinstance(f, http.HTTPFlow):
|
||||||
|
if not f.request:
|
||||||
|
return "Can't replay flow with missing request."
|
||||||
|
if f.request.raw_content is None:
|
||||||
|
return "Can't replay flow with missing content."
|
||||||
|
else:
|
||||||
|
return "Can only replay HTTP flows."
|
||||||
|
|
||||||
|
def load(self, loader):
|
||||||
|
loader.add_option(
|
||||||
|
"client_replay", typing.Sequence[str], [],
|
||||||
|
"Replay client requests from a saved file."
|
||||||
|
)
|
||||||
|
|
||||||
|
def configure(self, updated):
|
||||||
|
if "client_replay" in updated and ctx.options.client_replay:
|
||||||
|
try:
|
||||||
|
flows = io.read_flows_from_paths(ctx.options.client_replay)
|
||||||
|
except exceptions.FlowReadException as e:
|
||||||
|
raise exceptions.OptionsError(str(e))
|
||||||
|
self.start_replay(flows)
|
||||||
|
|
||||||
|
@command.command("replay.client.count")
|
||||||
|
def count(self) -> int:
|
||||||
|
"""
|
||||||
|
Approximate number of flows queued for replay.
|
||||||
|
"""
|
||||||
|
return self.queue.qsize() + int(bool(self.inflight))
|
||||||
|
|
||||||
|
@command.command("replay.client.stop")
|
||||||
|
def stop_replay(self) -> None:
|
||||||
|
"""
|
||||||
|
Clear the replay queue.
|
||||||
|
"""
|
||||||
|
updated = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
f = self.queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
f.revert()
|
||||||
|
updated.append(f)
|
||||||
|
|
||||||
|
ctx.master.addons.trigger("update", updated)
|
||||||
|
ctx.log.alert("Client replay queue cleared.")
|
||||||
|
|
||||||
|
@command.command("replay.client")
|
||||||
|
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
|
||||||
|
"""
|
||||||
|
Add flows to the replay queue, skipping flows that can't be replayed.
|
||||||
|
"""
|
||||||
|
updated: typing.List[http.HTTPFlow] = []
|
||||||
|
for f in flows:
|
||||||
|
err = self.check(f)
|
||||||
|
if err:
|
||||||
|
ctx.log.warn(err)
|
||||||
|
continue
|
||||||
|
|
||||||
|
http_flow = typing.cast(http.HTTPFlow, f)
|
||||||
|
|
||||||
|
# Prepare the flow for replay
|
||||||
|
http_flow.backup()
|
||||||
|
http_flow.is_replay = "request"
|
||||||
|
http_flow.response = None
|
||||||
|
http_flow.error = None
|
||||||
|
self.queue.put_nowait(http_flow)
|
||||||
|
updated.append(http_flow)
|
||||||
|
ctx.master.addons.trigger("update", updated)
|
||||||
|
|
||||||
|
@command.command("replay.client.file")
|
||||||
|
def load_file(self, path: mitmproxy.types.Path) -> None:
|
||||||
|
"""
|
||||||
|
Load flows from file, and add them to the replay queue.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
flows = io.read_flows_from_paths([path])
|
||||||
|
except exceptions.FlowReadException as e:
|
||||||
|
raise exceptions.CommandError(str(e))
|
||||||
|
self.start_replay(flows)
|
@ -29,17 +29,17 @@ class AsyncReply(controller.Reply):
|
|||||||
self.obj.error = Error.KILLED_MESSAGE
|
self.obj.error = Error.KILLED_MESSAGE
|
||||||
|
|
||||||
|
|
||||||
class ProxyConnectionHandler(server.ConnectionHandler):
|
class ProxyConnectionHandler(server.StreamConnectionHandler):
|
||||||
master: master.Master
|
master: master.Master
|
||||||
|
|
||||||
def __init__(self, master, r, w, options):
|
def __init__(self, master, r, w, options):
|
||||||
self.master = master
|
self.master = master
|
||||||
super().__init__(r, w, options)
|
super().__init__(r, w, options)
|
||||||
self.log_prefix = f"{human.format_address(self.client.address)}: "
|
self.log_prefix = f"{human.format_address(self.client.peername)}: "
|
||||||
|
|
||||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||||
with self.timeout_watchdog.disarm():
|
with self.timeout_watchdog.disarm():
|
||||||
# TODO: We currently only support single-argument hooks.
|
# We currently only support single-argument hooks.
|
||||||
data, = hook.as_tuple()
|
data, = hook.as_tuple()
|
||||||
data.reply = AsyncReply(data)
|
data.reply = AsyncReply(data)
|
||||||
await self.master.addons.handle_lifecycle(hook.name, data)
|
await self.master.addons.handle_lifecycle(hook.name, data)
|
||||||
@ -88,19 +88,22 @@ class Proxyserver:
|
|||||||
def configure(self, updated):
|
def configure(self, updated):
|
||||||
if not self.is_running:
|
if not self.is_running:
|
||||||
return
|
return
|
||||||
if any(x in updated for x in ["listen_host", "listen_port"]):
|
if any(x in updated for x in ["server", "listen_host", "listen_port"]):
|
||||||
asyncio.ensure_future(self.start_server())
|
asyncio.ensure_future(self.refresh_server())
|
||||||
|
|
||||||
async def start_server(self):
|
async def refresh_server(self):
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
if self.server:
|
if self.server:
|
||||||
await self.shutdown_server()
|
await self.shutdown_server()
|
||||||
print("Starting server...")
|
self.server = None
|
||||||
self.server = await asyncio.start_server(
|
if ctx.options.server:
|
||||||
self.handle_connection,
|
self.server = await asyncio.start_server(
|
||||||
self.options.listen_host,
|
self.handle_connection,
|
||||||
self.options.listen_port,
|
self.options.listen_host,
|
||||||
)
|
self.options.listen_port,
|
||||||
|
)
|
||||||
|
addrs = {f"http://{human.format_address(s.getsockname())}" for s in self.server.sockets}
|
||||||
|
ctx.log.info(f"Proxy server listening at {' and '.join(addrs)}")
|
||||||
|
|
||||||
async def shutdown_server(self):
|
async def shutdown_server(self):
|
||||||
print("Stopping server...")
|
print("Stopping server...")
|
||||||
|
@ -113,7 +113,7 @@ class Client(Connection):
|
|||||||
'address': self.peername,
|
'address': self.peername,
|
||||||
'alpn_proto_negotiated': self.alpn,
|
'alpn_proto_negotiated': self.alpn,
|
||||||
'cipher_name': self.cipher,
|
'cipher_name': self.cipher,
|
||||||
'clientcert': self.certificate_list[0] if self.certificate_list else None,
|
'clientcert': self.certificate_list[0].get_state() if self.certificate_list else None,
|
||||||
'id': self.id,
|
'id': self.id,
|
||||||
'mitmcert': None,
|
'mitmcert': None,
|
||||||
'sni': self.sni,
|
'sni': self.sni,
|
||||||
@ -181,7 +181,7 @@ class Server(Connection):
|
|||||||
return {
|
return {
|
||||||
'address': self.address,
|
'address': self.address,
|
||||||
'alpn_proto_negotiated': self.alpn,
|
'alpn_proto_negotiated': self.alpn,
|
||||||
'cert': self.certificate_list[0] if self.certificate_list else None,
|
'cert': self.certificate_list[0].get_state() if self.certificate_list else None,
|
||||||
'id': self.id,
|
'id': self.id,
|
||||||
'ip_address': self.peername,
|
'ip_address': self.peername,
|
||||||
'sni': self.sni,
|
'sni': self.sni,
|
||||||
|
@ -114,17 +114,18 @@ class Http1Server(Http1Connection):
|
|||||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||||
assert event.stream_id == self.stream_id
|
assert event.stream_id == self.stream_id
|
||||||
if isinstance(event, ResponseHeaders):
|
if isinstance(event, ResponseHeaders):
|
||||||
self.response = event.response
|
self.response = response = event.response
|
||||||
|
|
||||||
if self.response.is_http2:
|
if response.is_http2:
|
||||||
# Convert to an HTTP/1 request.
|
response = response.copy()
|
||||||
self.response.http_version = b"HTTP/1.1"
|
# Convert to an HTTP/1 response.
|
||||||
|
response.http_version = b"HTTP/1.1"
|
||||||
# not everyone supports empty reason phrases, so we better make up one.
|
# not everyone supports empty reason phrases, so we better make up one.
|
||||||
self.response.reason = status_codes.RESPONSES.get(self.response.status_code, "")
|
response.reason = status_codes.RESPONSES.get(response.status_code, "")
|
||||||
# Shall we set a Content-Length header here if there is none?
|
# Shall we set a Content-Length header here if there is none?
|
||||||
# For now, let's try to modify as little as possible.
|
# For now, let's try to modify as little as possible.
|
||||||
|
|
||||||
raw = http1.assemble_response_head(event.response)
|
raw = http1.assemble_response_head(response)
|
||||||
yield commands.SendData(self.conn, raw)
|
yield commands.SendData(self.conn, raw)
|
||||||
if self.request.first_line_format == "authority":
|
if self.request.first_line_format == "authority":
|
||||||
assert self.state == self.wait
|
assert self.state == self.wait
|
||||||
@ -237,13 +238,15 @@ class Http1Client(Http1Connection):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(event, RequestHeaders):
|
if isinstance(event, RequestHeaders):
|
||||||
if event.request.is_http2:
|
request = event.request
|
||||||
|
if request.is_http2:
|
||||||
# Convert to an HTTP/1 request.
|
# Convert to an HTTP/1 request.
|
||||||
event.request.http_version = b"HTTP/1.1"
|
request = request.copy() # (we could probably be a bit more efficient here.)
|
||||||
if "Host" not in event.request.headers and event.request.authority:
|
request.http_version = b"HTTP/1.1"
|
||||||
event.request.headers.insert(0, "Host", event.request.authority)
|
if "Host" not in request.headers and request.authority:
|
||||||
event.request.authority = b""
|
request.headers.insert(0, "Host", request.authority)
|
||||||
raw = http1.assemble_request_head(event.request)
|
request.authority = b""
|
||||||
|
raw = http1.assemble_request_head(request)
|
||||||
yield commands.SendData(self.conn, raw)
|
yield commands.SendData(self.conn, raw)
|
||||||
elif isinstance(event, RequestData):
|
elif isinstance(event, RequestData):
|
||||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||||
|
@ -194,7 +194,7 @@ class _TLSLayer(tunnel.TunnelLayer):
|
|||||||
self.conn.cipher = self.tls.get_cipher_name()
|
self.conn.cipher = self.tls.get_cipher_name()
|
||||||
self.conn.cipher_list = self.tls.get_cipher_list()
|
self.conn.cipher_list = self.tls.get_cipher_list()
|
||||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||||
yield commands.Log(f"TLS established: {self.conn}")
|
yield commands.Log(f"TLS established: {self.conn}", "debug")
|
||||||
yield from self.receive_data(b"")
|
yield from self.receive_data(b"")
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ import socket
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
|
from abc import ABC
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -76,23 +77,16 @@ class ConnectionIO:
|
|||||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||||
transports: typing.MutableMapping[Connection, ConnectionIO]
|
transports: typing.MutableMapping[Connection, ConnectionIO]
|
||||||
timeout_watchdog: TimeoutWatchdog
|
timeout_watchdog: TimeoutWatchdog
|
||||||
|
client: Client
|
||||||
|
|
||||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
def __init__(self, context: Context) -> None:
|
||||||
self.client = Client(
|
self.client = context.client
|
||||||
writer.get_extra_info('peername'),
|
self.transports = {}
|
||||||
writer.get_extra_info('sockname'),
|
|
||||||
time.time(),
|
|
||||||
)
|
|
||||||
self.context = Context(self.client, options)
|
|
||||||
self.transports = {
|
|
||||||
self.client: ConnectionIO(handler=None, reader=reader, writer=writer)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Ask for the first layer right away.
|
# Ask for the first layer right away.
|
||||||
# In a reverse proxy scenario, this is necessary as we would otherwise hang
|
# In a reverse proxy scenario, this is necessary as we would otherwise hang
|
||||||
# on protocols that start with a server greeting.
|
# on protocols that start with a server greeting.
|
||||||
self.layer = layer.NextLayer(self.context, ask_on_start=True)
|
self.layer = layer.NextLayer(context, ask_on_start=True)
|
||||||
|
|
||||||
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
||||||
|
|
||||||
async def handle_client(self) -> None:
|
async def handle_client(self) -> None:
|
||||||
@ -264,7 +258,19 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
self.transports[connection].handler.cancel()
|
self.transports[connection].handler.cancel()
|
||||||
|
|
||||||
|
|
||||||
class SimpleConnectionHandler(ConnectionHandler):
|
class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
|
||||||
|
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||||
|
client = Client(
|
||||||
|
writer.get_extra_info('peername'),
|
||||||
|
writer.get_extra_info('sockname'),
|
||||||
|
time.time(),
|
||||||
|
)
|
||||||
|
context = Context(client, options)
|
||||||
|
super().__init__(context)
|
||||||
|
self.transports[client] = ConnectionIO(handler=None, reader=reader, writer=writer)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleConnectionHandler(StreamConnectionHandler):
|
||||||
"""Simple handler that does not really process any hooks."""
|
"""Simple handler that does not really process any hooks."""
|
||||||
|
|
||||||
hook_handlers: typing.Dict[str, typing.Callable]
|
hook_handlers: typing.Dict[str, typing.Callable]
|
||||||
|
Loading…
Reference in New Issue
Block a user