[sans-io] better represent half-closed connections

This commit is contained in:
Maximilian Hils 2019-11-05 20:17:02 +01:00
parent 9f075fbbbc
commit 16abce77ea
2 changed files with 54 additions and 30 deletions

View File

@ -1,21 +1,32 @@
import copy from enum import Flag, auto
from typing import Optional, List, Union, Sequence, Any from typing import List, Optional, Sequence, Union
from mitmproxy.options import Options from mitmproxy.options import Options
class ConnectionState(Flag):
CLOSED = 0
CAN_READ = auto()
CAN_WRITE = auto()
OPEN = CAN_READ | CAN_WRITE
class Connection: class Connection:
""" """
Connections exposed to the layers only contain metadata, no socket objects. Connections exposed to the layers only contain metadata, no socket objects.
""" """
address: tuple address: tuple
connected: bool = False state: ConnectionState
tls: bool = False tls: bool = False
tls_established: bool = False tls_established: bool = False
alpn: Optional[bytes] = None alpn: Optional[bytes] = None
alpn_offers: Sequence[bytes] = () alpn_offers: Sequence[bytes] = ()
sni: Union[bytes, bool, None] sni: Union[bytes, bool, None]
@property
def connected(self):
return self.state is ConnectionState.OPEN
def __repr__(self): def __repr__(self):
return f"{type(self).__name__}({repr(self.__dict__)})" return f"{type(self).__name__}({repr(self.__dict__)})"
@ -25,7 +36,7 @@ class Client(Connection):
def __init__(self, address): def __init__(self, address):
self.address = address self.address = address
self.connected = True self.state = ConnectionState.OPEN
class Server(Connection): class Server(Connection):
@ -35,6 +46,7 @@ class Server(Connection):
def __init__(self, address: Optional[tuple]): def __init__(self, address: Optional[tuple]):
self.address = address self.address = address
self.state = ConnectionState.CLOSED
class Context: class Context:

View File

@ -10,15 +10,17 @@ import abc
import asyncio import asyncio
import logging import logging
import socket import socket
import traceback
import typing import typing
from mitmproxy import http, options as moptions from mitmproxy import http, options as moptions
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer, layers from mitmproxy.proxy2 import commands, events, layer, layers
from mitmproxy.proxy2.context import Client, Connection, Context from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
from mitmproxy.proxy2.layers import glue from mitmproxy.proxy2.layers import glue
from mitmproxy.utils import human from mitmproxy.utils import human
class StreamIO(typing.NamedTuple): class StreamIO(typing.NamedTuple):
r: asyncio.StreamReader r: asyncio.StreamReader
w: asyncio.StreamWriter w: asyncio.StreamWriter
@ -27,7 +29,7 @@ class StreamIO(typing.NamedTuple):
class ConnectionHandler(metaclass=abc.ABCMeta): class ConnectionHandler(metaclass=abc.ABCMeta):
transports: typing.MutableMapping[Connection, StreamIO] transports: typing.MutableMapping[Connection, StreamIO]
def __init__(self, reader, writer, options): def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
addr = writer.get_extra_info('peername') addr = writer.get_extra_info('peername')
self.client = Client(addr) self.client = Client(addr)
@ -43,7 +45,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.client: StreamIO(reader, writer) self.client: StreamIO(reader, writer)
} }
async def handle_client(self): async def handle_client(self) -> None:
# FIXME: Work around log suppression in core. # FIXME: Work around log suppression in core.
logging.getLogger('asyncio').setLevel(logging.DEBUG) logging.getLogger('asyncio').setLevel(logging.DEBUG)
@ -55,28 +57,29 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.log("[sans-io] clientdisconnect") self.log("[sans-io] clientdisconnect")
if self.transports: if self.transports:
self.log("[sans-io] closing transports...")
await asyncio.wait([ await asyncio.wait([
self.close_connection(x) self.close_connection(x)
for x in self.transports for x in self.transports
]) ])
# self._debug("transports closed!") self.log("[sans-io] transports closed!")
async def close_connection(self, connection): async def close_connection(self, connection: Connection) -> None:
try:
io = self.transports.pop(connection)
except KeyError:
self.log(f"already closed: {connection}", "warn")
return
else:
self.log(f"closing {connection}", "debug") self.log(f"closing {connection}", "debug")
try: connection.state = ConnectionState.CLOSED
await io.w.drain() io = self.transports.pop(connection)
io.w.write_eof()
except socket.error:
pass
io.w.close() io.w.close()
await io.w.wait_closed()
async def handle_connection(self, connection): async def shutdown_connection(self, connection: Connection) -> None:
assert connection.state & ConnectionState.CAN_WRITE
io = self.transports[connection]
self.log(f"shutting down {connection}", "debug")
io.w.write_eof()
connection.state &= ~ConnectionState.CAN_WRITE
async def handle_connection(self, connection: Connection) -> None:
reader, writer = self.transports[connection] reader, writer = self.transports[connection]
while True: while True:
try: try:
@ -86,15 +89,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
if data: if data:
self.server_event(events.DataReceived(connection, data)) self.server_event(events.DataReceived(connection, data))
else: else:
connection.connected = False if connection.state is ConnectionState.CAN_READ:
if connection in self.transports:
await self.close_connection(connection) await self.close_connection(connection)
self.server_event(events.ConnectionClosed(connection)) self.server_event(events.ConnectionClosed(connection))
break break
async def open_connection(self, command: commands.OpenConnection): async def open_connection(self, command: commands.OpenConnection) -> None:
if not command.connection.address: if not command.connection.address:
raise ValueError("Cannot open connection, no hostname given.") raise ValueError("Cannot open connection, no hostname given.")
assert command.connection not in self.transports
try: try:
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
*command.connection.address *command.connection.address
@ -104,7 +107,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
else: else:
self.log("serverconnect") self.log("serverconnect")
self.transports[command.connection] = StreamIO(reader, writer) self.transports[command.connection] = StreamIO(reader, writer)
command.connection.connected = True command.connection.state = ConnectionState.OPEN
self.server_event(events.OpenConnectionReply(command, None)) self.server_event(events.OpenConnectionReply(command, None))
await self.handle_connection(command.connection) await self.handle_connection(command.connection)
self.log("serverdisconnect") self.log("serverdisconnect")
@ -117,7 +120,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
print(message) print(message)
def server_event(self, event: events.Event) -> None: def server_event(self, event: events.Event) -> None:
layer_commands = self.layer.handle_event(event) try:
layer_commands = list(self.layer.handle_event(event))
except Exception:
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
return
for command in layer_commands: for command in layer_commands:
if isinstance(command, commands.OpenConnection): if isinstance(command, commands.OpenConnection):
asyncio.ensure_future( asyncio.ensure_future(
@ -126,9 +133,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
elif isinstance(command, commands.SendData): elif isinstance(command, commands.SendData):
self.transports[command.connection].w.write(command.data) self.transports[command.connection].w.write(command.data)
elif isinstance(command, commands.CloseConnection): elif isinstance(command, commands.CloseConnection):
if command.connection == self.client:
asyncio.ensure_future( asyncio.ensure_future(
self.close_connection(command.connection) self.close_connection(command.connection)
) )
else:
asyncio.ensure_future(
self.shutdown_connection(command.connection)
)
elif isinstance(command, glue.GlueGetConnectionHandler): elif isinstance(command, glue.GlueGetConnectionHandler):
self.server_event(glue.GlueGetConnectionHandlerReply(command, self)) self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
elif isinstance(command, commands.Hook): elif isinstance(command, commands.Hook):