mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
[sans-io] more testing, more bugfixes!
This commit is contained in:
parent
38f006eb9a
commit
c639fafd64
@ -22,7 +22,10 @@ class AsyncReply(controller.Reply):
|
|||||||
|
|
||||||
def commit(self):
|
def commit(self):
|
||||||
super().commit()
|
super().commit()
|
||||||
|
try:
|
||||||
self.loop.call_soon_threadsafe(lambda: self.done.set())
|
self.loop.call_soon_threadsafe(lambda: self.done.set())
|
||||||
|
except RuntimeError:
|
||||||
|
pass # event loop may already be closed.
|
||||||
|
|
||||||
def kill(self, force=False):
|
def kill(self, force=False):
|
||||||
warnings.warn("reply.kill() is deprecated, set the error attribute instead.", PendingDeprecationWarning)
|
warnings.warn("reply.kill() is deprecated, set the error attribute instead.", PendingDeprecationWarning)
|
||||||
@ -49,9 +52,11 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
|
|||||||
def log(self, message: str, level: str = "info") -> None:
|
def log(self, message: str, level: str = "info") -> None:
|
||||||
x = log.LogEntry(self.log_prefix + message, level)
|
x = log.LogEntry(self.log_prefix + message, level)
|
||||||
x.reply = controller.DummyReply()
|
x.reply = controller.DummyReply()
|
||||||
asyncio.ensure_future(
|
coro = self.master.addons.handle_lifecycle("log", x)
|
||||||
self.master.addons.handle_lifecycle("log", x)
|
try:
|
||||||
)
|
asyncio.ensure_future(coro)
|
||||||
|
except RuntimeError:
|
||||||
|
coro.close() # event loop may already be closed, but we don't want a "has never been awaited error"
|
||||||
|
|
||||||
|
|
||||||
class Proxyserver:
|
class Proxyserver:
|
||||||
@ -113,6 +118,7 @@ class Proxyserver:
|
|||||||
self.server = None
|
self.server = None
|
||||||
|
|
||||||
async def handle_connection(self, r, w):
|
async def handle_connection(self, r, w):
|
||||||
|
asyncio.current_task().set_name(f"proxy connection handler {w.get_extra_info('peername')}")
|
||||||
handler = ProxyConnectionHandler(
|
handler = ProxyConnectionHandler(
|
||||||
self.master,
|
self.master,
|
||||||
r,
|
r,
|
||||||
|
@ -62,8 +62,12 @@ class Layer:
|
|||||||
return f"{type(self).__name__}({state})"
|
return f"{type(self).__name__}({state})"
|
||||||
|
|
||||||
def __debug(self, message):
|
def __debug(self, message):
|
||||||
|
if len(message) > 512:
|
||||||
|
message = message[:512] + "…"
|
||||||
if Layer.__last_debug_message == message:
|
if Layer.__last_debug_message == message:
|
||||||
message = message.split("\n", 1)[0].strip()
|
message = message.split("\n", 1)[0].strip()
|
||||||
|
if len(message) > 256:
|
||||||
|
message = message[:256] + "…"
|
||||||
else:
|
else:
|
||||||
Layer.__last_debug_message = message
|
Layer.__last_debug_message = message
|
||||||
return commands.Log(
|
return commands.Log(
|
||||||
|
@ -142,6 +142,7 @@ class Http1Server(Http1Connection):
|
|||||||
elif isinstance(event, ResponseEndOfMessage):
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||||
|
yield from self.mark_done(response=True)
|
||||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
elif self.request.first_line_format != "authority":
|
elif self.request.first_line_format != "authority":
|
||||||
|
@ -67,8 +67,12 @@ class Http2Connection(HttpConnection):
|
|||||||
self.h2_conn.send_data(event.stream_id, event.data)
|
self.h2_conn.send_data(event.stream_id, event.data)
|
||||||
elif isinstance(event, self.SendEndOfMessage):
|
elif isinstance(event, self.SendEndOfMessage):
|
||||||
self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
|
self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
|
||||||
|
if self.h2_conn.streams.get(event.stream_id).closed:
|
||||||
|
self.streams.pop(event.stream_id, None)
|
||||||
elif isinstance(event, self.SendProtocolError):
|
elif isinstance(event, self.SendProtocolError):
|
||||||
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
|
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
|
||||||
|
if self.h2_conn.streams.get(event.stream_id).closed:
|
||||||
|
self.streams.pop(event.stream_id, None)
|
||||||
else:
|
else:
|
||||||
raise AssertionError(f"Unexpected event: {event}")
|
raise AssertionError(f"Unexpected event: {event}")
|
||||||
yield SendData(self.conn, self.h2_conn.data_to_send())
|
yield SendData(self.conn, self.h2_conn.data_to_send())
|
||||||
@ -119,10 +123,14 @@ class Http2Connection(HttpConnection):
|
|||||||
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
|
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
|
||||||
elif state is StreamState.EXPECTING_HEADERS:
|
elif state is StreamState.EXPECTING_HEADERS:
|
||||||
raise AssertionError("unreachable")
|
raise AssertionError("unreachable")
|
||||||
|
if self.h2_conn.streams.get(event.stream_id).closed:
|
||||||
self.streams.pop(event.stream_id, None)
|
self.streams.pop(event.stream_id, None)
|
||||||
elif isinstance(event, h2.events.StreamReset):
|
elif isinstance(event, h2.events.StreamReset):
|
||||||
if event.stream_id in self.streams:
|
if event.stream_id in self.streams:
|
||||||
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset"))
|
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, f"Stream reset, error code {event.error_code}"))
|
||||||
|
self.streams.pop(event.stream_id)
|
||||||
|
else:
|
||||||
|
pass # We don't track priority frames which could be followed by a stream reset here.
|
||||||
elif isinstance(event, h2.exceptions.ProtocolError):
|
elif isinstance(event, h2.exceptions.ProtocolError):
|
||||||
yield from self.protocol_error(f"HTTP/2 protocol error: {event}")
|
yield from self.protocol_error(f"HTTP/2 protocol error: {event}")
|
||||||
return True
|
return True
|
||||||
@ -282,6 +290,11 @@ class Http2Client(Http2Connection):
|
|||||||
]
|
]
|
||||||
if event.request.authority:
|
if event.request.authority:
|
||||||
pseudo_headers.append((b":authority", event.request.data.authority))
|
pseudo_headers.append((b":authority", event.request.data.authority))
|
||||||
|
elif not event.request.is_http2:
|
||||||
|
host_header = event.request.headers.pop("host", None)
|
||||||
|
if host_header:
|
||||||
|
pseudo_headers.append((b":authority", host_header))
|
||||||
|
|
||||||
headers = pseudo_headers + list(event.request.headers.fields)
|
headers = pseudo_headers + list(event.request.headers.fields)
|
||||||
if not event.request.is_http2:
|
if not event.request.is_http2:
|
||||||
headers = normalize_h1_headers(headers, True)
|
headers = normalize_h1_headers(headers, True)
|
||||||
|
@ -174,7 +174,7 @@ class _TLSLayer(tunnel.TunnelLayer):
|
|||||||
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
|
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
|
||||||
err = f"The remote server does not speak TLS."
|
err = f"The remote server does not speak TLS."
|
||||||
else:
|
else:
|
||||||
err = repr(e)
|
err = f"OpenSSL {e!r}"
|
||||||
return False, err
|
return False, err
|
||||||
else:
|
else:
|
||||||
# Get all peer certificates.
|
# Get all peer certificates.
|
||||||
@ -194,7 +194,8 @@ class _TLSLayer(tunnel.TunnelLayer):
|
|||||||
self.conn.cipher = self.tls.get_cipher_name()
|
self.conn.cipher = self.tls.get_cipher_name()
|
||||||
self.conn.cipher_list = self.tls.get_cipher_list()
|
self.conn.cipher_list = self.tls.get_cipher_list()
|
||||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||||
yield commands.Log(f"TLS established: {self.conn}", "debug")
|
if self.debug:
|
||||||
|
yield commands.Log(f"{self.debug}[tls] tls established: {self.conn}", "debug")
|
||||||
yield from self.receive_data(b"")
|
yield from self.receive_data(b"")
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
@ -218,7 +218,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
if isinstance(command, commands.OpenConnection):
|
if isinstance(command, commands.OpenConnection):
|
||||||
assert command.connection not in self.transports
|
assert command.connection not in self.transports
|
||||||
handler = asyncio.create_task(
|
handler = asyncio.create_task(
|
||||||
self.open_connection(command)
|
self.open_connection(command),
|
||||||
|
name=f"open_connection {command.connection.address}"
|
||||||
)
|
)
|
||||||
self.transports[command.connection] = ConnectionIO(handler=handler)
|
self.transports[command.connection] = ConnectionIO(handler=handler)
|
||||||
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
|
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
|
||||||
@ -231,7 +232,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
socket = self.transports[command.connection].writer.get_extra_info("socket")
|
socket = self.transports[command.connection].writer.get_extra_info("socket")
|
||||||
self.server_event(events.GetSocketReply(command, socket))
|
self.server_event(events.GetSocketReply(command, socket))
|
||||||
elif isinstance(command, commands.Hook):
|
elif isinstance(command, commands.Hook):
|
||||||
asyncio.create_task(self.hook_task(command))
|
asyncio.create_task(self.hook_task(command), name=f"hook {command.name}")
|
||||||
elif isinstance(command, commands.Log):
|
elif isinstance(command, commands.Log):
|
||||||
self.log(command.message, command.level)
|
self.log(command.message, command.level)
|
||||||
else:
|
else:
|
||||||
|
@ -40,7 +40,10 @@ class TunnelLayer(layer.Layer):
|
|||||||
|
|
||||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.Start):
|
if isinstance(event, events.Start):
|
||||||
if self.tunnel_connection.connected:
|
if self.tunnel_connection.state is not context.ConnectionState.CLOSED:
|
||||||
|
# we might be in the interesting state here where the connection is already half-closed,
|
||||||
|
# for example because next_layer buffered events and the client disconnected in the meantime.
|
||||||
|
# we still expect a close event to arrive, so we carry on here as normal for now.
|
||||||
self.tunnel_state = TunnelState.ESTABLISHING
|
self.tunnel_state = TunnelState.ESTABLISHING
|
||||||
yield from self.start_handshake()
|
yield from self.start_handshake()
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
@ -62,7 +65,7 @@ class TunnelLayer(layer.Layer):
|
|||||||
yield from self.receive_data(event.data)
|
yield from self.receive_data(event.data)
|
||||||
elif isinstance(event, events.ConnectionClosed):
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
if self.conn != self.tunnel_connection:
|
if self.conn != self.tunnel_connection:
|
||||||
self.conn.state &= ~context.ConnectionState.CAN_READ
|
self.conn.state = context.ConnectionState.CLOSED
|
||||||
if self.tunnel_state is TunnelState.OPEN:
|
if self.tunnel_state is TunnelState.OPEN:
|
||||||
yield from self.receive_close()
|
yield from self.receive_close()
|
||||||
elif self.tunnel_state is TunnelState.ESTABLISHING:
|
elif self.tunnel_state is TunnelState.ESTABLISHING:
|
||||||
@ -79,7 +82,8 @@ class TunnelLayer(layer.Layer):
|
|||||||
yield from self.send_data(command.data)
|
yield from self.send_data(command.data)
|
||||||
elif isinstance(command, commands.CloseConnection):
|
elif isinstance(command, commands.CloseConnection):
|
||||||
if self.conn != self.tunnel_connection:
|
if self.conn != self.tunnel_connection:
|
||||||
self.conn.state &= ~context.ConnectionState.CAN_WRITE
|
# we don't have a use case for distinguishing between read/write here
|
||||||
|
self.conn.state = context.ConnectionState.CLOSED
|
||||||
yield from self.send_close()
|
yield from self.send_close()
|
||||||
elif isinstance(command, commands.OpenConnection):
|
elif isinstance(command, commands.OpenConnection):
|
||||||
# create our own OpenConnection command object that blocks here.
|
# create our own OpenConnection command object that blocks here.
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mitmproxy import options
|
from mitmproxy import log, options
|
||||||
from mitmproxy.addons.proxyserver import Proxyserver
|
from mitmproxy.addons.proxyserver import Proxyserver
|
||||||
|
from mitmproxy.addons.termlog import TermLog
|
||||||
from mitmproxy.proxy2 import context
|
from mitmproxy.proxy2 import context
|
||||||
|
|
||||||
|
|
||||||
@ -9,6 +10,7 @@ from mitmproxy.proxy2 import context
|
|||||||
def tctx() -> context.Context:
|
def tctx() -> context.Context:
|
||||||
opts = options.Options()
|
opts = options.Options()
|
||||||
Proxyserver().load(opts)
|
Proxyserver().load(opts)
|
||||||
|
TermLog().load(opts)
|
||||||
return context.Context(
|
return context.Context(
|
||||||
context.Client(
|
context.Client(
|
||||||
("client", 1234),
|
("client", 1234),
|
||||||
|
Loading…
Reference in New Issue
Block a user