mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] better represent half-closed connections
This commit is contained in:
parent
9f075fbbbc
commit
16abce77ea
@ -1,21 +1,32 @@
|
|||||||
import copy
|
from enum import Flag, auto
|
||||||
from typing import Optional, List, Union, Sequence, Any
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
from mitmproxy.options import Options
|
from mitmproxy.options import Options
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionState(Flag):
|
||||||
|
CLOSED = 0
|
||||||
|
CAN_READ = auto()
|
||||||
|
CAN_WRITE = auto()
|
||||||
|
OPEN = CAN_READ | CAN_WRITE
|
||||||
|
|
||||||
|
|
||||||
class Connection:
|
class Connection:
|
||||||
"""
|
"""
|
||||||
Connections exposed to the layers only contain metadata, no socket objects.
|
Connections exposed to the layers only contain metadata, no socket objects.
|
||||||
"""
|
"""
|
||||||
address: tuple
|
address: tuple
|
||||||
connected: bool = False
|
state: ConnectionState
|
||||||
tls: bool = False
|
tls: bool = False
|
||||||
tls_established: bool = False
|
tls_established: bool = False
|
||||||
alpn: Optional[bytes] = None
|
alpn: Optional[bytes] = None
|
||||||
alpn_offers: Sequence[bytes] = ()
|
alpn_offers: Sequence[bytes] = ()
|
||||||
sni: Union[bytes, bool, None]
|
sni: Union[bytes, bool, None]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connected(self):
|
||||||
|
return self.state is ConnectionState.OPEN
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{type(self).__name__}({repr(self.__dict__)})"
|
return f"{type(self).__name__}({repr(self.__dict__)})"
|
||||||
|
|
||||||
@ -25,7 +36,7 @@ class Client(Connection):
|
|||||||
|
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
self.address = address
|
self.address = address
|
||||||
self.connected = True
|
self.state = ConnectionState.OPEN
|
||||||
|
|
||||||
|
|
||||||
class Server(Connection):
|
class Server(Connection):
|
||||||
@ -35,6 +46,7 @@ class Server(Connection):
|
|||||||
|
|
||||||
def __init__(self, address: Optional[tuple]):
|
def __init__(self, address: Optional[tuple]):
|
||||||
self.address = address
|
self.address = address
|
||||||
|
self.state = ConnectionState.CLOSED
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
|
@ -10,15 +10,17 @@ import abc
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from mitmproxy import http, options as moptions
|
from mitmproxy import http, options as moptions
|
||||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
from mitmproxy.proxy2 import commands, events, layer, layers
|
from mitmproxy.proxy2 import commands, events, layer, layers
|
||||||
from mitmproxy.proxy2.context import Client, Connection, Context
|
from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
|
||||||
from mitmproxy.proxy2.layers import glue
|
from mitmproxy.proxy2.layers import glue
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
|
|
||||||
|
|
||||||
class StreamIO(typing.NamedTuple):
|
class StreamIO(typing.NamedTuple):
|
||||||
r: asyncio.StreamReader
|
r: asyncio.StreamReader
|
||||||
w: asyncio.StreamWriter
|
w: asyncio.StreamWriter
|
||||||
@ -27,7 +29,7 @@ class StreamIO(typing.NamedTuple):
|
|||||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||||
transports: typing.MutableMapping[Connection, StreamIO]
|
transports: typing.MutableMapping[Connection, StreamIO]
|
||||||
|
|
||||||
def __init__(self, reader, writer, options):
|
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||||
addr = writer.get_extra_info('peername')
|
addr = writer.get_extra_info('peername')
|
||||||
|
|
||||||
self.client = Client(addr)
|
self.client = Client(addr)
|
||||||
@ -43,7 +45,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
self.client: StreamIO(reader, writer)
|
self.client: StreamIO(reader, writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
async def handle_client(self):
|
async def handle_client(self) -> None:
|
||||||
# FIXME: Work around log suppression in core.
|
# FIXME: Work around log suppression in core.
|
||||||
logging.getLogger('asyncio').setLevel(logging.DEBUG)
|
logging.getLogger('asyncio').setLevel(logging.DEBUG)
|
||||||
|
|
||||||
@ -55,28 +57,29 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
self.log("[sans-io] clientdisconnect")
|
self.log("[sans-io] clientdisconnect")
|
||||||
|
|
||||||
if self.transports:
|
if self.transports:
|
||||||
|
self.log("[sans-io] closing transports...")
|
||||||
await asyncio.wait([
|
await asyncio.wait([
|
||||||
self.close_connection(x)
|
self.close_connection(x)
|
||||||
for x in self.transports
|
for x in self.transports
|
||||||
])
|
])
|
||||||
# self._debug("transports closed!")
|
self.log("[sans-io] transports closed!")
|
||||||
|
|
||||||
async def close_connection(self, connection):
|
async def close_connection(self, connection: Connection) -> None:
|
||||||
try:
|
self.log(f"closing {connection}", "debug")
|
||||||
io = self.transports.pop(connection)
|
connection.state = ConnectionState.CLOSED
|
||||||
except KeyError:
|
io = self.transports.pop(connection)
|
||||||
self.log(f"already closed: {connection}", "warn")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.log(f"closing {connection}", "debug")
|
|
||||||
try:
|
|
||||||
await io.w.drain()
|
|
||||||
io.w.write_eof()
|
|
||||||
except socket.error:
|
|
||||||
pass
|
|
||||||
io.w.close()
|
io.w.close()
|
||||||
|
await io.w.wait_closed()
|
||||||
|
|
||||||
async def handle_connection(self, connection):
|
async def shutdown_connection(self, connection: Connection) -> None:
|
||||||
|
assert connection.state & ConnectionState.CAN_WRITE
|
||||||
|
io = self.transports[connection]
|
||||||
|
self.log(f"shutting down {connection}", "debug")
|
||||||
|
|
||||||
|
io.w.write_eof()
|
||||||
|
connection.state &= ~ConnectionState.CAN_WRITE
|
||||||
|
|
||||||
|
async def handle_connection(self, connection: Connection) -> None:
|
||||||
reader, writer = self.transports[connection]
|
reader, writer = self.transports[connection]
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -86,15 +89,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
if data:
|
if data:
|
||||||
self.server_event(events.DataReceived(connection, data))
|
self.server_event(events.DataReceived(connection, data))
|
||||||
else:
|
else:
|
||||||
connection.connected = False
|
if connection.state is ConnectionState.CAN_READ:
|
||||||
if connection in self.transports:
|
|
||||||
await self.close_connection(connection)
|
await self.close_connection(connection)
|
||||||
self.server_event(events.ConnectionClosed(connection))
|
self.server_event(events.ConnectionClosed(connection))
|
||||||
break
|
break
|
||||||
|
|
||||||
async def open_connection(self, command: commands.OpenConnection):
|
async def open_connection(self, command: commands.OpenConnection) -> None:
|
||||||
if not command.connection.address:
|
if not command.connection.address:
|
||||||
raise ValueError("Cannot open connection, no hostname given.")
|
raise ValueError("Cannot open connection, no hostname given.")
|
||||||
|
assert command.connection not in self.transports
|
||||||
try:
|
try:
|
||||||
reader, writer = await asyncio.open_connection(
|
reader, writer = await asyncio.open_connection(
|
||||||
*command.connection.address
|
*command.connection.address
|
||||||
@ -104,7 +107,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
else:
|
else:
|
||||||
self.log("serverconnect")
|
self.log("serverconnect")
|
||||||
self.transports[command.connection] = StreamIO(reader, writer)
|
self.transports[command.connection] = StreamIO(reader, writer)
|
||||||
command.connection.connected = True
|
command.connection.state = ConnectionState.OPEN
|
||||||
self.server_event(events.OpenConnectionReply(command, None))
|
self.server_event(events.OpenConnectionReply(command, None))
|
||||||
await self.handle_connection(command.connection)
|
await self.handle_connection(command.connection)
|
||||||
self.log("serverdisconnect")
|
self.log("serverdisconnect")
|
||||||
@ -117,7 +120,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
print(message)
|
print(message)
|
||||||
|
|
||||||
def server_event(self, event: events.Event) -> None:
|
def server_event(self, event: events.Event) -> None:
|
||||||
layer_commands = self.layer.handle_event(event)
|
try:
|
||||||
|
layer_commands = list(self.layer.handle_event(event))
|
||||||
|
except Exception:
|
||||||
|
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
|
||||||
|
return
|
||||||
for command in layer_commands:
|
for command in layer_commands:
|
||||||
if isinstance(command, commands.OpenConnection):
|
if isinstance(command, commands.OpenConnection):
|
||||||
asyncio.ensure_future(
|
asyncio.ensure_future(
|
||||||
@ -126,9 +133,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
elif isinstance(command, commands.SendData):
|
elif isinstance(command, commands.SendData):
|
||||||
self.transports[command.connection].w.write(command.data)
|
self.transports[command.connection].w.write(command.data)
|
||||||
elif isinstance(command, commands.CloseConnection):
|
elif isinstance(command, commands.CloseConnection):
|
||||||
asyncio.ensure_future(
|
if command.connection == self.client:
|
||||||
self.close_connection(command.connection)
|
asyncio.ensure_future(
|
||||||
)
|
self.close_connection(command.connection)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
asyncio.ensure_future(
|
||||||
|
self.shutdown_connection(command.connection)
|
||||||
|
)
|
||||||
elif isinstance(command, glue.GlueGetConnectionHandler):
|
elif isinstance(command, glue.GlueGetConnectionHandler):
|
||||||
self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
|
self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
|
||||||
elif isinstance(command, commands.Hook):
|
elif isinstance(command, commands.Hook):
|
||||||
|
Loading…
Reference in New Issue
Block a user