diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index 5ee2fc118..69d33dace 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -1,21 +1,32 @@ -import copy -from typing import Optional, List, Union, Sequence, Any +from enum import Flag, auto +from typing import List, Optional, Sequence, Union from mitmproxy.options import Options +class ConnectionState(Flag): + CLOSED = 0 + CAN_READ = auto() + CAN_WRITE = auto() + OPEN = CAN_READ | CAN_WRITE + + class Connection: """ Connections exposed to the layers only contain metadata, no socket objects. """ address: tuple - connected: bool = False + state: ConnectionState tls: bool = False tls_established: bool = False alpn: Optional[bytes] = None alpn_offers: Sequence[bytes] = () sni: Union[bytes, bool, None] + @property + def connected(self): + return self.state is ConnectionState.OPEN + def __repr__(self): return f"{type(self).__name__}({repr(self.__dict__)})" @@ -25,7 +36,7 @@ class Client(Connection): def __init__(self, address): self.address = address - self.connected = True + self.state = ConnectionState.OPEN class Server(Connection): @@ -35,6 +46,7 @@ class Server(Connection): def __init__(self, address: Optional[tuple]): self.address = address + self.state = ConnectionState.CLOSED class Context: diff --git a/mitmproxy/proxy2/server.py b/mitmproxy/proxy2/server.py index 0ad029311..ab5ee66b4 100644 --- a/mitmproxy/proxy2/server.py +++ b/mitmproxy/proxy2/server.py @@ -10,15 +10,17 @@ import abc import asyncio import logging import socket +import traceback import typing from mitmproxy import http, options as moptions from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy2 import commands, events, layer, layers -from mitmproxy.proxy2.context import Client, Connection, Context +from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context from mitmproxy.proxy2.layers import glue from mitmproxy.utils import human + class StreamIO(typing.NamedTuple): r: asyncio.StreamReader w: asyncio.StreamWriter @@ -27,7 +29,7 @@ class StreamIO(typing.NamedTuple): class ConnectionHandler(metaclass=abc.ABCMeta): transports: typing.MutableMapping[Connection, StreamIO] - def __init__(self, reader, writer, options): + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None: addr = writer.get_extra_info('peername') self.client = Client(addr) @@ -43,7 +45,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): self.client: StreamIO(reader, writer) } - async def handle_client(self): + async def handle_client(self) -> None: # FIXME: Work around log suppression in core. logging.getLogger('asyncio').setLevel(logging.DEBUG) @@ -55,28 +57,29 @@ class ConnectionHandler(metaclass=abc.ABCMeta): self.log("[sans-io] clientdisconnect") if self.transports: + self.log("[sans-io] closing transports...") await asyncio.wait([ self.close_connection(x) for x in self.transports ]) - # self._debug("transports closed!") + self.log("[sans-io] transports closed!") - async def close_connection(self, connection): - try: - io = self.transports.pop(connection) - except KeyError: - self.log(f"already closed: {connection}", "warn") - return - else: - self.log(f"closing {connection}", "debug") - try: - await io.w.drain() - io.w.write_eof() - except socket.error: - pass + async def close_connection(self, connection: Connection) -> None: + self.log(f"closing {connection}", "debug") + connection.state = ConnectionState.CLOSED + io = self.transports.pop(connection) io.w.close() + await io.w.wait_closed() - async def handle_connection(self, connection): + async def shutdown_connection(self, connection: Connection) -> None: + assert connection.state & ConnectionState.CAN_WRITE + io = self.transports[connection] + self.log(f"shutting down {connection}", "debug") + + io.w.write_eof() + connection.state &= ~ConnectionState.CAN_WRITE + + async def handle_connection(self, connection: Connection) -> None: reader, writer = self.transports[connection] while True: try: @@ -86,15 +89,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta): if data: self.server_event(events.DataReceived(connection, data)) else: - connection.connected = False - if connection in self.transports: + if connection.state is ConnectionState.CAN_READ: await self.close_connection(connection) self.server_event(events.ConnectionClosed(connection)) break - async def open_connection(self, command: commands.OpenConnection): + async def open_connection(self, command: commands.OpenConnection) -> None: if not command.connection.address: raise ValueError("Cannot open connection, no hostname given.") + assert command.connection not in self.transports try: reader, writer = await asyncio.open_connection( *command.connection.address @@ -104,7 +107,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): else: self.log("serverconnect") self.transports[command.connection] = StreamIO(reader, writer) - command.connection.connected = True + command.connection.state = ConnectionState.OPEN self.server_event(events.OpenConnectionReply(command, None)) await self.handle_connection(command.connection) self.log("serverdisconnect") @@ -117,7 +120,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta): print(message) def server_event(self, event: events.Event) -> None: - layer_commands = self.layer.handle_event(event) + try: + layer_commands = list(self.layer.handle_event(event)) + except Exception: + self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error") + return for command in layer_commands: if isinstance(command, commands.OpenConnection): asyncio.ensure_future( @@ -126,9 +133,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta): elif isinstance(command, commands.SendData): self.transports[command.connection].w.write(command.data) elif isinstance(command, commands.CloseConnection): - asyncio.ensure_future( - self.close_connection(command.connection) - ) + if command.connection == self.client: + asyncio.ensure_future( + self.close_connection(command.connection) + ) + else: + asyncio.ensure_future( + self.shutdown_connection(command.connection) + ) elif isinstance(command, glue.GlueGetConnectionHandler): self.server_event(glue.GlueGetConnectionHandlerReply(command, self)) elif isinstance(command, commands.Hook):