[sans-io] more testing, more bugfixes!

This commit is contained in:
Maximilian Hils 2020-11-20 18:55:47 +01:00
parent 38f006eb9a
commit c639fafd64
8 changed files with 46 additions and 14 deletions

View File

@ -22,7 +22,10 @@ class AsyncReply(controller.Reply):
def commit(self): def commit(self):
super().commit() super().commit()
try:
self.loop.call_soon_threadsafe(lambda: self.done.set()) self.loop.call_soon_threadsafe(lambda: self.done.set())
except RuntimeError:
pass # event loop may already be closed.
def kill(self, force=False): def kill(self, force=False):
warnings.warn("reply.kill() is deprecated, set the error attribute instead.", PendingDeprecationWarning) warnings.warn("reply.kill() is deprecated, set the error attribute instead.", PendingDeprecationWarning)
@ -49,9 +52,11 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
def log(self, message: str, level: str = "info") -> None: def log(self, message: str, level: str = "info") -> None:
x = log.LogEntry(self.log_prefix + message, level) x = log.LogEntry(self.log_prefix + message, level)
x.reply = controller.DummyReply() x.reply = controller.DummyReply()
asyncio.ensure_future( coro = self.master.addons.handle_lifecycle("log", x)
self.master.addons.handle_lifecycle("log", x) try:
) asyncio.ensure_future(coro)
except RuntimeError:
coro.close() # event loop may already be closed, but we don't want a "has never been awaited error"
class Proxyserver: class Proxyserver:
@ -113,6 +118,7 @@ class Proxyserver:
self.server = None self.server = None
async def handle_connection(self, r, w): async def handle_connection(self, r, w):
asyncio.current_task().set_name(f"proxy connection handler {w.get_extra_info('peername')}")
handler = ProxyConnectionHandler( handler = ProxyConnectionHandler(
self.master, self.master,
r, r,

View File

@ -62,8 +62,12 @@ class Layer:
return f"{type(self).__name__}({state})" return f"{type(self).__name__}({state})"
def __debug(self, message): def __debug(self, message):
if len(message) > 512:
message = message[:512] + ""
if Layer.__last_debug_message == message: if Layer.__last_debug_message == message:
message = message.split("\n", 1)[0].strip() message = message.split("\n", 1)[0].strip()
if len(message) > 256:
message = message[:256] + ""
else: else:
Layer.__last_debug_message = message Layer.__last_debug_message = message
return commands.Log( return commands.Log(

View File

@ -142,6 +142,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower(): if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n") yield commands.SendData(self.conn, b"0\r\n\r\n")
yield from self.mark_done(response=True)
elif http1.expected_http_body_size(self.request, self.response) == -1: elif http1.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
elif self.request.first_line_format != "authority": elif self.request.first_line_format != "authority":

View File

@ -67,8 +67,12 @@ class Http2Connection(HttpConnection):
self.h2_conn.send_data(event.stream_id, event.data) self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, self.SendEndOfMessage): elif isinstance(event, self.SendEndOfMessage):
self.h2_conn.send_data(event.stream_id, b"", end_stream=True) self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
if self.h2_conn.streams.get(event.stream_id).closed:
self.streams.pop(event.stream_id, None)
elif isinstance(event, self.SendProtocolError): elif isinstance(event, self.SendProtocolError):
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR) self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
if self.h2_conn.streams.get(event.stream_id).closed:
self.streams.pop(event.stream_id, None)
else: else:
raise AssertionError(f"Unexpected event: {event}") raise AssertionError(f"Unexpected event: {event}")
yield SendData(self.conn, self.h2_conn.data_to_send()) yield SendData(self.conn, self.h2_conn.data_to_send())
@ -119,10 +123,14 @@ class Http2Connection(HttpConnection):
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id)) yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
elif state is StreamState.EXPECTING_HEADERS: elif state is StreamState.EXPECTING_HEADERS:
raise AssertionError("unreachable") raise AssertionError("unreachable")
if self.h2_conn.streams.get(event.stream_id).closed:
self.streams.pop(event.stream_id, None) self.streams.pop(event.stream_id, None)
elif isinstance(event, h2.events.StreamReset): elif isinstance(event, h2.events.StreamReset):
if event.stream_id in self.streams: if event.stream_id in self.streams:
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset")) yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, f"Stream reset, error code {event.error_code}"))
self.streams.pop(event.stream_id)
else:
pass # We don't track priority frames which could be followed by a stream reset here.
elif isinstance(event, h2.exceptions.ProtocolError): elif isinstance(event, h2.exceptions.ProtocolError):
yield from self.protocol_error(f"HTTP/2 protocol error: {event}") yield from self.protocol_error(f"HTTP/2 protocol error: {event}")
return True return True
@ -282,6 +290,11 @@ class Http2Client(Http2Connection):
] ]
if event.request.authority: if event.request.authority:
pseudo_headers.append((b":authority", event.request.data.authority)) pseudo_headers.append((b":authority", event.request.data.authority))
elif not event.request.is_http2:
host_header = event.request.headers.pop("host", None)
if host_header:
pseudo_headers.append((b":authority", host_header))
headers = pseudo_headers + list(event.request.headers.fields) headers = pseudo_headers + list(event.request.headers.fields)
if not event.request.is_http2: if not event.request.is_http2:
headers = normalize_h1_headers(headers, True) headers = normalize_h1_headers(headers, True)

View File

@ -174,7 +174,7 @@ class _TLSLayer(tunnel.TunnelLayer):
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii(): elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
err = f"The remote server does not speak TLS." err = f"The remote server does not speak TLS."
else: else:
err = repr(e) err = f"OpenSSL {e!r}"
return False, err return False, err
else: else:
# Get all peer certificates. # Get all peer certificates.
@ -194,7 +194,8 @@ class _TLSLayer(tunnel.TunnelLayer):
self.conn.cipher = self.tls.get_cipher_name() self.conn.cipher = self.tls.get_cipher_name()
self.conn.cipher_list = self.tls.get_cipher_list() self.conn.cipher_list = self.tls.get_cipher_list()
self.conn.tls_version = self.tls.get_protocol_version_name() self.conn.tls_version = self.tls.get_protocol_version_name()
yield commands.Log(f"TLS established: {self.conn}", "debug") if self.debug:
yield commands.Log(f"{self.debug}[tls] tls established: {self.conn}", "debug")
yield from self.receive_data(b"") yield from self.receive_data(b"")
return True, None return True, None

View File

@ -218,7 +218,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
if isinstance(command, commands.OpenConnection): if isinstance(command, commands.OpenConnection):
assert command.connection not in self.transports assert command.connection not in self.transports
handler = asyncio.create_task( handler = asyncio.create_task(
self.open_connection(command) self.open_connection(command),
name=f"open_connection {command.connection.address}"
) )
self.transports[command.connection] = ConnectionIO(handler=handler) self.transports[command.connection] = ConnectionIO(handler=handler)
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports: elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
@ -231,7 +232,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
socket = self.transports[command.connection].writer.get_extra_info("socket") socket = self.transports[command.connection].writer.get_extra_info("socket")
self.server_event(events.GetSocketReply(command, socket)) self.server_event(events.GetSocketReply(command, socket))
elif isinstance(command, commands.Hook): elif isinstance(command, commands.Hook):
asyncio.create_task(self.hook_task(command)) asyncio.create_task(self.hook_task(command), name=f"hook {command.name}")
elif isinstance(command, commands.Log): elif isinstance(command, commands.Log):
self.log(command.message, command.level) self.log(command.message, command.level)
else: else:

View File

@ -40,7 +40,10 @@ class TunnelLayer(layer.Layer):
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.Start): if isinstance(event, events.Start):
if self.tunnel_connection.connected: if self.tunnel_connection.state is not context.ConnectionState.CLOSED:
# we might be in the interesting state here where the connection is already half-closed,
# for example because next_layer buffered events and the client disconnected in the meantime.
# we still expect a close event to arrive, so we carry on here as normal for now.
self.tunnel_state = TunnelState.ESTABLISHING self.tunnel_state = TunnelState.ESTABLISHING
yield from self.start_handshake() yield from self.start_handshake()
yield from self.event_to_child(event) yield from self.event_to_child(event)
@ -62,7 +65,7 @@ class TunnelLayer(layer.Layer):
yield from self.receive_data(event.data) yield from self.receive_data(event.data)
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.ConnectionClosed):
if self.conn != self.tunnel_connection: if self.conn != self.tunnel_connection:
self.conn.state &= ~context.ConnectionState.CAN_READ self.conn.state = context.ConnectionState.CLOSED
if self.tunnel_state is TunnelState.OPEN: if self.tunnel_state is TunnelState.OPEN:
yield from self.receive_close() yield from self.receive_close()
elif self.tunnel_state is TunnelState.ESTABLISHING: elif self.tunnel_state is TunnelState.ESTABLISHING:
@ -79,7 +82,8 @@ class TunnelLayer(layer.Layer):
yield from self.send_data(command.data) yield from self.send_data(command.data)
elif isinstance(command, commands.CloseConnection): elif isinstance(command, commands.CloseConnection):
if self.conn != self.tunnel_connection: if self.conn != self.tunnel_connection:
self.conn.state &= ~context.ConnectionState.CAN_WRITE # we don't have a use case for distinguishing between read/write here
self.conn.state = context.ConnectionState.CLOSED
yield from self.send_close() yield from self.send_close()
elif isinstance(command, commands.OpenConnection): elif isinstance(command, commands.OpenConnection):
# create our own OpenConnection command object that blocks here. # create our own OpenConnection command object that blocks here.

View File

@ -1,7 +1,8 @@
import pytest import pytest
from mitmproxy import options from mitmproxy import log, options
from mitmproxy.addons.proxyserver import Proxyserver from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.addons.termlog import TermLog
from mitmproxy.proxy2 import context from mitmproxy.proxy2 import context
@ -9,6 +10,7 @@ from mitmproxy.proxy2 import context
def tctx() -> context.Context: def tctx() -> context.Context:
opts = options.Options() opts = options.Options()
Proxyserver().load(opts) Proxyserver().load(opts)
TermLog().load(opts)
return context.Context( return context.Context(
context.Client( context.Client(
("client", 1234), ("client", 1234),