[sans-io] better represent half-closed connections

This commit is contained in:
Maximilian Hils 2019-11-05 20:17:02 +01:00
parent 9f075fbbbc
commit 16abce77ea
2 changed files with 54 additions and 30 deletions

View File

@ -1,21 +1,32 @@
import copy
from typing import Optional, List, Union, Sequence, Any
from enum import Flag, auto
from typing import List, Optional, Sequence, Union
from mitmproxy.options import Options
class ConnectionState(Flag):
CLOSED = 0
CAN_READ = auto()
CAN_WRITE = auto()
OPEN = CAN_READ | CAN_WRITE
class Connection:
"""
Connections exposed to the layers only contain metadata, no socket objects.
"""
address: tuple
connected: bool = False
state: ConnectionState
tls: bool = False
tls_established: bool = False
alpn: Optional[bytes] = None
alpn_offers: Sequence[bytes] = ()
sni: Union[bytes, bool, None]
@property
def connected(self):
return self.state is ConnectionState.OPEN
def __repr__(self):
return f"{type(self).__name__}({repr(self.__dict__)})"
@ -25,7 +36,7 @@ class Client(Connection):
def __init__(self, address):
self.address = address
self.connected = True
self.state = ConnectionState.OPEN
class Server(Connection):
@ -35,6 +46,7 @@ class Server(Connection):
def __init__(self, address: Optional[tuple]):
self.address = address
self.state = ConnectionState.CLOSED
class Context:

View File

@ -10,15 +10,17 @@ import abc
import asyncio
import logging
import socket
import traceback
import typing
from mitmproxy import http, options as moptions
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer, layers
from mitmproxy.proxy2.context import Client, Connection, Context
from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
from mitmproxy.proxy2.layers import glue
from mitmproxy.utils import human
class StreamIO(typing.NamedTuple):
r: asyncio.StreamReader
w: asyncio.StreamWriter
@ -27,7 +29,7 @@ class StreamIO(typing.NamedTuple):
class ConnectionHandler(metaclass=abc.ABCMeta):
transports: typing.MutableMapping[Connection, StreamIO]
def __init__(self, reader, writer, options):
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
addr = writer.get_extra_info('peername')
self.client = Client(addr)
@ -43,7 +45,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.client: StreamIO(reader, writer)
}
async def handle_client(self):
async def handle_client(self) -> None:
# FIXME: Work around log suppression in core.
logging.getLogger('asyncio').setLevel(logging.DEBUG)
@ -55,28 +57,29 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.log("[sans-io] clientdisconnect")
if self.transports:
self.log("[sans-io] closing transports...")
await asyncio.wait([
self.close_connection(x)
for x in self.transports
])
# self._debug("transports closed!")
self.log("[sans-io] transports closed!")
async def close_connection(self, connection):
try:
io = self.transports.pop(connection)
except KeyError:
self.log(f"already closed: {connection}", "warn")
return
else:
async def close_connection(self, connection: Connection) -> None:
self.log(f"closing {connection}", "debug")
try:
await io.w.drain()
io.w.write_eof()
except socket.error:
pass
connection.state = ConnectionState.CLOSED
io = self.transports.pop(connection)
io.w.close()
await io.w.wait_closed()
async def handle_connection(self, connection):
async def shutdown_connection(self, connection: Connection) -> None:
assert connection.state & ConnectionState.CAN_WRITE
io = self.transports[connection]
self.log(f"shutting down {connection}", "debug")
io.w.write_eof()
connection.state &= ~ConnectionState.CAN_WRITE
async def handle_connection(self, connection: Connection) -> None:
reader, writer = self.transports[connection]
while True:
try:
@ -86,15 +89,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
if data:
self.server_event(events.DataReceived(connection, data))
else:
connection.connected = False
if connection in self.transports:
if connection.state is ConnectionState.CAN_READ:
await self.close_connection(connection)
self.server_event(events.ConnectionClosed(connection))
break
async def open_connection(self, command: commands.OpenConnection):
async def open_connection(self, command: commands.OpenConnection) -> None:
if not command.connection.address:
raise ValueError("Cannot open connection, no hostname given.")
assert command.connection not in self.transports
try:
reader, writer = await asyncio.open_connection(
*command.connection.address
@ -104,7 +107,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
else:
self.log("serverconnect")
self.transports[command.connection] = StreamIO(reader, writer)
command.connection.connected = True
command.connection.state = ConnectionState.OPEN
self.server_event(events.OpenConnectionReply(command, None))
await self.handle_connection(command.connection)
self.log("serverdisconnect")
@ -117,7 +120,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
print(message)
def server_event(self, event: events.Event) -> None:
layer_commands = self.layer.handle_event(event)
try:
layer_commands = list(self.layer.handle_event(event))
except Exception:
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
return
for command in layer_commands:
if isinstance(command, commands.OpenConnection):
asyncio.ensure_future(
@ -126,9 +133,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
elif isinstance(command, commands.SendData):
self.transports[command.connection].w.write(command.data)
elif isinstance(command, commands.CloseConnection):
if command.connection == self.client:
asyncio.ensure_future(
self.close_connection(command.connection)
)
else:
asyncio.ensure_future(
self.shutdown_connection(command.connection)
)
elif isinstance(command, glue.GlueGetConnectionHandler):
self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
elif isinstance(command, commands.Hook):