mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] better represent half-closed connections
This commit is contained in:
parent
9f075fbbbc
commit
16abce77ea
@ -1,21 +1,32 @@
|
||||
import copy
|
||||
from typing import Optional, List, Union, Sequence, Any
|
||||
from enum import Flag, auto
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from mitmproxy.options import Options
|
||||
|
||||
|
||||
class ConnectionState(Flag):
|
||||
CLOSED = 0
|
||||
CAN_READ = auto()
|
||||
CAN_WRITE = auto()
|
||||
OPEN = CAN_READ | CAN_WRITE
|
||||
|
||||
|
||||
class Connection:
|
||||
"""
|
||||
Connections exposed to the layers only contain metadata, no socket objects.
|
||||
"""
|
||||
address: tuple
|
||||
connected: bool = False
|
||||
state: ConnectionState
|
||||
tls: bool = False
|
||||
tls_established: bool = False
|
||||
alpn: Optional[bytes] = None
|
||||
alpn_offers: Sequence[bytes] = ()
|
||||
sni: Union[bytes, bool, None]
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
return self.state is ConnectionState.OPEN
|
||||
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}({repr(self.__dict__)})"
|
||||
|
||||
@ -25,7 +36,7 @@ class Client(Connection):
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = address
|
||||
self.connected = True
|
||||
self.state = ConnectionState.OPEN
|
||||
|
||||
|
||||
class Server(Connection):
|
||||
@ -35,6 +46,7 @@ class Server(Connection):
|
||||
|
||||
def __init__(self, address: Optional[tuple]):
|
||||
self.address = address
|
||||
self.state = ConnectionState.CLOSED
|
||||
|
||||
|
||||
class Context:
|
||||
|
@ -10,15 +10,17 @@ import abc
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from mitmproxy import http, options as moptions
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events, layer, layers
|
||||
from mitmproxy.proxy2.context import Client, Connection, Context
|
||||
from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
|
||||
from mitmproxy.proxy2.layers import glue
|
||||
from mitmproxy.utils import human
|
||||
|
||||
|
||||
class StreamIO(typing.NamedTuple):
|
||||
r: asyncio.StreamReader
|
||||
w: asyncio.StreamWriter
|
||||
@ -27,7 +29,7 @@ class StreamIO(typing.NamedTuple):
|
||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
transports: typing.MutableMapping[Connection, StreamIO]
|
||||
|
||||
def __init__(self, reader, writer, options):
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||
addr = writer.get_extra_info('peername')
|
||||
|
||||
self.client = Client(addr)
|
||||
@ -43,7 +45,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
self.client: StreamIO(reader, writer)
|
||||
}
|
||||
|
||||
async def handle_client(self):
|
||||
async def handle_client(self) -> None:
|
||||
# FIXME: Work around log suppression in core.
|
||||
logging.getLogger('asyncio').setLevel(logging.DEBUG)
|
||||
|
||||
@ -55,28 +57,29 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
self.log("[sans-io] clientdisconnect")
|
||||
|
||||
if self.transports:
|
||||
self.log("[sans-io] closing transports...")
|
||||
await asyncio.wait([
|
||||
self.close_connection(x)
|
||||
for x in self.transports
|
||||
])
|
||||
# self._debug("transports closed!")
|
||||
self.log("[sans-io] transports closed!")
|
||||
|
||||
async def close_connection(self, connection):
|
||||
try:
|
||||
io = self.transports.pop(connection)
|
||||
except KeyError:
|
||||
self.log(f"already closed: {connection}", "warn")
|
||||
return
|
||||
else:
|
||||
self.log(f"closing {connection}", "debug")
|
||||
try:
|
||||
await io.w.drain()
|
||||
io.w.write_eof()
|
||||
except socket.error:
|
||||
pass
|
||||
async def close_connection(self, connection: Connection) -> None:
|
||||
self.log(f"closing {connection}", "debug")
|
||||
connection.state = ConnectionState.CLOSED
|
||||
io = self.transports.pop(connection)
|
||||
io.w.close()
|
||||
await io.w.wait_closed()
|
||||
|
||||
async def handle_connection(self, connection):
|
||||
async def shutdown_connection(self, connection: Connection) -> None:
|
||||
assert connection.state & ConnectionState.CAN_WRITE
|
||||
io = self.transports[connection]
|
||||
self.log(f"shutting down {connection}", "debug")
|
||||
|
||||
io.w.write_eof()
|
||||
connection.state &= ~ConnectionState.CAN_WRITE
|
||||
|
||||
async def handle_connection(self, connection: Connection) -> None:
|
||||
reader, writer = self.transports[connection]
|
||||
while True:
|
||||
try:
|
||||
@ -86,15 +89,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
if data:
|
||||
self.server_event(events.DataReceived(connection, data))
|
||||
else:
|
||||
connection.connected = False
|
||||
if connection in self.transports:
|
||||
if connection.state is ConnectionState.CAN_READ:
|
||||
await self.close_connection(connection)
|
||||
self.server_event(events.ConnectionClosed(connection))
|
||||
break
|
||||
|
||||
async def open_connection(self, command: commands.OpenConnection):
|
||||
async def open_connection(self, command: commands.OpenConnection) -> None:
|
||||
if not command.connection.address:
|
||||
raise ValueError("Cannot open connection, no hostname given.")
|
||||
assert command.connection not in self.transports
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*command.connection.address
|
||||
@ -104,7 +107,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
else:
|
||||
self.log("serverconnect")
|
||||
self.transports[command.connection] = StreamIO(reader, writer)
|
||||
command.connection.connected = True
|
||||
command.connection.state = ConnectionState.OPEN
|
||||
self.server_event(events.OpenConnectionReply(command, None))
|
||||
await self.handle_connection(command.connection)
|
||||
self.log("serverdisconnect")
|
||||
@ -117,7 +120,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
print(message)
|
||||
|
||||
def server_event(self, event: events.Event) -> None:
|
||||
layer_commands = self.layer.handle_event(event)
|
||||
try:
|
||||
layer_commands = list(self.layer.handle_event(event))
|
||||
except Exception:
|
||||
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
|
||||
return
|
||||
for command in layer_commands:
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
asyncio.ensure_future(
|
||||
@ -126,9 +133,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
elif isinstance(command, commands.SendData):
|
||||
self.transports[command.connection].w.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
asyncio.ensure_future(
|
||||
self.close_connection(command.connection)
|
||||
)
|
||||
if command.connection == self.client:
|
||||
asyncio.ensure_future(
|
||||
self.close_connection(command.connection)
|
||||
)
|
||||
else:
|
||||
asyncio.ensure_future(
|
||||
self.shutdown_connection(command.connection)
|
||||
)
|
||||
elif isinstance(command, glue.GlueGetConnectionHandler):
|
||||
self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
|
||||
elif isinstance(command, commands.Hook):
|
||||
|
Loading…
Reference in New Issue
Block a user