[sans-io] more testing, more bugfixes!

This commit is contained in:
Maximilian Hils 2020-11-20 18:55:47 +01:00
parent 38f006eb9a
commit c639fafd64
8 changed files with 46 additions and 14 deletions

View File

@ -22,7 +22,10 @@ class AsyncReply(controller.Reply):
def commit(self):
super().commit()
try:
self.loop.call_soon_threadsafe(lambda: self.done.set())
except RuntimeError:
pass # event loop may already be closed.
def kill(self, force=False):
warnings.warn("reply.kill() is deprecated, set the error attribute instead.", PendingDeprecationWarning)
@ -49,9 +52,11 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
def log(self, message: str, level: str = "info") -> None:
x = log.LogEntry(self.log_prefix + message, level)
x.reply = controller.DummyReply()
asyncio.ensure_future(
self.master.addons.handle_lifecycle("log", x)
)
coro = self.master.addons.handle_lifecycle("log", x)
try:
asyncio.ensure_future(coro)
except RuntimeError:
coro.close() # event loop may already be closed, but we don't want a "has never been awaited error"
class Proxyserver:
@ -113,6 +118,7 @@ class Proxyserver:
self.server = None
async def handle_connection(self, r, w):
asyncio.current_task().set_name(f"proxy connection handler {w.get_extra_info('peername')}")
handler = ProxyConnectionHandler(
self.master,
r,

View File

@ -62,8 +62,12 @@ class Layer:
return f"{type(self).__name__}({state})"
def __debug(self, message):
if len(message) > 512:
message = message[:512] + ""
if Layer.__last_debug_message == message:
message = message.split("\n", 1)[0].strip()
if len(message) > 256:
message = message[:256] + ""
else:
Layer.__last_debug_message = message
return commands.Log(

View File

@ -142,6 +142,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, ResponseEndOfMessage):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
yield from self.mark_done(response=True)
elif http1.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn)
elif self.request.first_line_format != "authority":

View File

@ -67,8 +67,12 @@ class Http2Connection(HttpConnection):
self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, self.SendEndOfMessage):
self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
if self.h2_conn.streams.get(event.stream_id).closed:
self.streams.pop(event.stream_id, None)
elif isinstance(event, self.SendProtocolError):
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
if self.h2_conn.streams.get(event.stream_id).closed:
self.streams.pop(event.stream_id, None)
else:
raise AssertionError(f"Unexpected event: {event}")
yield SendData(self.conn, self.h2_conn.data_to_send())
@ -119,10 +123,14 @@ class Http2Connection(HttpConnection):
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
elif state is StreamState.EXPECTING_HEADERS:
raise AssertionError("unreachable")
if self.h2_conn.streams.get(event.stream_id).closed:
self.streams.pop(event.stream_id, None)
elif isinstance(event, h2.events.StreamReset):
if event.stream_id in self.streams:
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset"))
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, f"Stream reset, error code {event.error_code}"))
self.streams.pop(event.stream_id)
else:
pass # We don't track priority frames which could be followed by a stream reset here.
elif isinstance(event, h2.exceptions.ProtocolError):
yield from self.protocol_error(f"HTTP/2 protocol error: {event}")
return True
@ -282,6 +290,11 @@ class Http2Client(Http2Connection):
]
if event.request.authority:
pseudo_headers.append((b":authority", event.request.data.authority))
elif not event.request.is_http2:
host_header = event.request.headers.pop("host", None)
if host_header:
pseudo_headers.append((b":authority", host_header))
headers = pseudo_headers + list(event.request.headers.fields)
if not event.request.is_http2:
headers = normalize_h1_headers(headers, True)

View File

@ -174,7 +174,7 @@ class _TLSLayer(tunnel.TunnelLayer):
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
err = f"The remote server does not speak TLS."
else:
err = repr(e)
err = f"OpenSSL {e!r}"
return False, err
else:
# Get all peer certificates.
@ -194,7 +194,8 @@ class _TLSLayer(tunnel.TunnelLayer):
self.conn.cipher = self.tls.get_cipher_name()
self.conn.cipher_list = self.tls.get_cipher_list()
self.conn.tls_version = self.tls.get_protocol_version_name()
yield commands.Log(f"TLS established: {self.conn}", "debug")
if self.debug:
yield commands.Log(f"{self.debug}[tls] tls established: {self.conn}", "debug")
yield from self.receive_data(b"")
return True, None

View File

@ -218,7 +218,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
if isinstance(command, commands.OpenConnection):
assert command.connection not in self.transports
handler = asyncio.create_task(
self.open_connection(command)
self.open_connection(command),
name=f"open_connection {command.connection.address}"
)
self.transports[command.connection] = ConnectionIO(handler=handler)
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
@ -231,7 +232,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
socket = self.transports[command.connection].writer.get_extra_info("socket")
self.server_event(events.GetSocketReply(command, socket))
elif isinstance(command, commands.Hook):
asyncio.create_task(self.hook_task(command))
asyncio.create_task(self.hook_task(command), name=f"hook {command.name}")
elif isinstance(command, commands.Log):
self.log(command.message, command.level)
else:

View File

@ -40,7 +40,10 @@ class TunnelLayer(layer.Layer):
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.Start):
if self.tunnel_connection.connected:
if self.tunnel_connection.state is not context.ConnectionState.CLOSED:
# we might be in the interesting state here where the connection is already half-closed,
# for example because next_layer buffered events and the client disconnected in the meantime.
# we still expect a close event to arrive, so we carry on here as normal for now.
self.tunnel_state = TunnelState.ESTABLISHING
yield from self.start_handshake()
yield from self.event_to_child(event)
@ -62,7 +65,7 @@ class TunnelLayer(layer.Layer):
yield from self.receive_data(event.data)
elif isinstance(event, events.ConnectionClosed):
if self.conn != self.tunnel_connection:
self.conn.state &= ~context.ConnectionState.CAN_READ
self.conn.state = context.ConnectionState.CLOSED
if self.tunnel_state is TunnelState.OPEN:
yield from self.receive_close()
elif self.tunnel_state is TunnelState.ESTABLISHING:
@ -79,7 +82,8 @@ class TunnelLayer(layer.Layer):
yield from self.send_data(command.data)
elif isinstance(command, commands.CloseConnection):
if self.conn != self.tunnel_connection:
self.conn.state &= ~context.ConnectionState.CAN_WRITE
# we don't have a use case for distinguishing between read/write here
self.conn.state = context.ConnectionState.CLOSED
yield from self.send_close()
elif isinstance(command, commands.OpenConnection):
# create our own OpenConnection command object that blocks here.

View File

@ -1,7 +1,8 @@
import pytest
from mitmproxy import options
from mitmproxy import log, options
from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.addons.termlog import TermLog
from mitmproxy.proxy2 import context
@ -9,6 +10,7 @@ from mitmproxy.proxy2 import context
def tctx() -> context.Context:
opts = options.Options()
Proxyserver().load(opts)
TermLog().load(opts)
return context.Context(
context.Client(
("client", 1234),