mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] more testing, more bugfixes!
This commit is contained in:
parent
38f006eb9a
commit
c639fafd64
@ -22,7 +22,10 @@ class AsyncReply(controller.Reply):
|
||||
|
||||
def commit(self):
|
||||
super().commit()
|
||||
try:
|
||||
self.loop.call_soon_threadsafe(lambda: self.done.set())
|
||||
except RuntimeError:
|
||||
pass # event loop may already be closed.
|
||||
|
||||
def kill(self, force=False):
|
||||
warnings.warn("reply.kill() is deprecated, set the error attribute instead.", PendingDeprecationWarning)
|
||||
@ -49,9 +52,11 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
|
||||
def log(self, message: str, level: str = "info") -> None:
|
||||
x = log.LogEntry(self.log_prefix + message, level)
|
||||
x.reply = controller.DummyReply()
|
||||
asyncio.ensure_future(
|
||||
self.master.addons.handle_lifecycle("log", x)
|
||||
)
|
||||
coro = self.master.addons.handle_lifecycle("log", x)
|
||||
try:
|
||||
asyncio.ensure_future(coro)
|
||||
except RuntimeError:
|
||||
coro.close() # event loop may already be closed, but we don't want a "has never been awaited error"
|
||||
|
||||
|
||||
class Proxyserver:
|
||||
@ -113,6 +118,7 @@ class Proxyserver:
|
||||
self.server = None
|
||||
|
||||
async def handle_connection(self, r, w):
|
||||
asyncio.current_task().set_name(f"proxy connection handler {w.get_extra_info('peername')}")
|
||||
handler = ProxyConnectionHandler(
|
||||
self.master,
|
||||
r,
|
||||
|
@ -62,8 +62,12 @@ class Layer:
|
||||
return f"{type(self).__name__}({state})"
|
||||
|
||||
def __debug(self, message):
|
||||
if len(message) > 512:
|
||||
message = message[:512] + "…"
|
||||
if Layer.__last_debug_message == message:
|
||||
message = message.split("\n", 1)[0].strip()
|
||||
if len(message) > 256:
|
||||
message = message[:256] + "…"
|
||||
else:
|
||||
Layer.__last_debug_message = message
|
||||
return commands.Log(
|
||||
|
@ -142,6 +142,7 @@ class Http1Server(Http1Connection):
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
yield from self.mark_done(response=True)
|
||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
elif self.request.first_line_format != "authority":
|
||||
|
@ -67,8 +67,12 @@ class Http2Connection(HttpConnection):
|
||||
self.h2_conn.send_data(event.stream_id, event.data)
|
||||
elif isinstance(event, self.SendEndOfMessage):
|
||||
self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
|
||||
if self.h2_conn.streams.get(event.stream_id).closed:
|
||||
self.streams.pop(event.stream_id, None)
|
||||
elif isinstance(event, self.SendProtocolError):
|
||||
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
|
||||
if self.h2_conn.streams.get(event.stream_id).closed:
|
||||
self.streams.pop(event.stream_id, None)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
yield SendData(self.conn, self.h2_conn.data_to_send())
|
||||
@ -119,10 +123,14 @@ class Http2Connection(HttpConnection):
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
|
||||
elif state is StreamState.EXPECTING_HEADERS:
|
||||
raise AssertionError("unreachable")
|
||||
if self.h2_conn.streams.get(event.stream_id).closed:
|
||||
self.streams.pop(event.stream_id, None)
|
||||
elif isinstance(event, h2.events.StreamReset):
|
||||
if event.stream_id in self.streams:
|
||||
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset"))
|
||||
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, f"Stream reset, error code {event.error_code}"))
|
||||
self.streams.pop(event.stream_id)
|
||||
else:
|
||||
pass # We don't track priority frames which could be followed by a stream reset here.
|
||||
elif isinstance(event, h2.exceptions.ProtocolError):
|
||||
yield from self.protocol_error(f"HTTP/2 protocol error: {event}")
|
||||
return True
|
||||
@ -282,6 +290,11 @@ class Http2Client(Http2Connection):
|
||||
]
|
||||
if event.request.authority:
|
||||
pseudo_headers.append((b":authority", event.request.data.authority))
|
||||
elif not event.request.is_http2:
|
||||
host_header = event.request.headers.pop("host", None)
|
||||
if host_header:
|
||||
pseudo_headers.append((b":authority", host_header))
|
||||
|
||||
headers = pseudo_headers + list(event.request.headers.fields)
|
||||
if not event.request.is_http2:
|
||||
headers = normalize_h1_headers(headers, True)
|
||||
|
@ -174,7 +174,7 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
|
||||
err = f"The remote server does not speak TLS."
|
||||
else:
|
||||
err = repr(e)
|
||||
err = f"OpenSSL {e!r}"
|
||||
return False, err
|
||||
else:
|
||||
# Get all peer certificates.
|
||||
@ -194,7 +194,8 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
self.conn.cipher = self.tls.get_cipher_name()
|
||||
self.conn.cipher_list = self.tls.get_cipher_list()
|
||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||
yield commands.Log(f"TLS established: {self.conn}", "debug")
|
||||
if self.debug:
|
||||
yield commands.Log(f"{self.debug}[tls] tls established: {self.conn}", "debug")
|
||||
yield from self.receive_data(b"")
|
||||
return True, None
|
||||
|
||||
|
@ -218,7 +218,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
assert command.connection not in self.transports
|
||||
handler = asyncio.create_task(
|
||||
self.open_connection(command)
|
||||
self.open_connection(command),
|
||||
name=f"open_connection {command.connection.address}"
|
||||
)
|
||||
self.transports[command.connection] = ConnectionIO(handler=handler)
|
||||
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
|
||||
@ -231,7 +232,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
socket = self.transports[command.connection].writer.get_extra_info("socket")
|
||||
self.server_event(events.GetSocketReply(command, socket))
|
||||
elif isinstance(command, commands.Hook):
|
||||
asyncio.create_task(self.hook_task(command))
|
||||
asyncio.create_task(self.hook_task(command), name=f"hook {command.name}")
|
||||
elif isinstance(command, commands.Log):
|
||||
self.log(command.message, command.level)
|
||||
else:
|
||||
|
@ -40,7 +40,10 @@ class TunnelLayer(layer.Layer):
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
if self.tunnel_connection.connected:
|
||||
if self.tunnel_connection.state is not context.ConnectionState.CLOSED:
|
||||
# we might be in the interesting state here where the connection is already half-closed,
|
||||
# for example because next_layer buffered events and the client disconnected in the meantime.
|
||||
# we still expect a close event to arrive, so we carry on here as normal for now.
|
||||
self.tunnel_state = TunnelState.ESTABLISHING
|
||||
yield from self.start_handshake()
|
||||
yield from self.event_to_child(event)
|
||||
@ -62,7 +65,7 @@ class TunnelLayer(layer.Layer):
|
||||
yield from self.receive_data(event.data)
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
if self.conn != self.tunnel_connection:
|
||||
self.conn.state &= ~context.ConnectionState.CAN_READ
|
||||
self.conn.state = context.ConnectionState.CLOSED
|
||||
if self.tunnel_state is TunnelState.OPEN:
|
||||
yield from self.receive_close()
|
||||
elif self.tunnel_state is TunnelState.ESTABLISHING:
|
||||
@ -79,7 +82,8 @@ class TunnelLayer(layer.Layer):
|
||||
yield from self.send_data(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
if self.conn != self.tunnel_connection:
|
||||
self.conn.state &= ~context.ConnectionState.CAN_WRITE
|
||||
# we don't have a use case for distinguishing between read/write here
|
||||
self.conn.state = context.ConnectionState.CLOSED
|
||||
yield from self.send_close()
|
||||
elif isinstance(command, commands.OpenConnection):
|
||||
# create our own OpenConnection command object that blocks here.
|
||||
|
@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from mitmproxy import options
|
||||
from mitmproxy import log, options
|
||||
from mitmproxy.addons.proxyserver import Proxyserver
|
||||
from mitmproxy.addons.termlog import TermLog
|
||||
from mitmproxy.proxy2 import context
|
||||
|
||||
|
||||
@ -9,6 +10,7 @@ from mitmproxy.proxy2 import context
|
||||
def tctx() -> context.Context:
|
||||
opts = options.Options()
|
||||
Proxyserver().load(opts)
|
||||
TermLog().load(opts)
|
||||
return context.Context(
|
||||
context.Client(
|
||||
("client", 1234),
|
||||
|
Loading…
Reference in New Issue
Block a user