diff --git a/mitmproxy/proxy/protocol2/server/server_async.py b/mitmproxy/proxy/protocol2/server/server_async.py index 0d5be65ff..59002e096 100644 --- a/mitmproxy/proxy/protocol2/server/server_async.py +++ b/mitmproxy/proxy/protocol2/server/server_async.py @@ -6,46 +6,46 @@ The very high level overview is as follows: - Process any commands from layer (such as opening a server connection) - Wait for any IO and send it as events to top layer. """ +import abc import asyncio -import collections import socket -from typing import MutableMapping +import typing -from mitmproxy import controller from mitmproxy.proxy.protocol2 import events, commands from mitmproxy.proxy.protocol2.context import Client, Context from mitmproxy.proxy.protocol2.context import Connection from mitmproxy.proxy.protocol2.reverse_proxy import ReverseProxy -StreamIO = collections.namedtuple('StreamIO', ['r', 'w']) + +class StreamIO(typing.NamedTuple): + r: asyncio.StreamReader + w: asyncio.StreamWriter -class ConnectionHandler: - def __init__(self, event_queue, reader, writer): +class ConnectionHandler(metaclass=abc.ABCMeta): + transports: typing.MutableMapping[Connection, StreamIO] + + def __init__(self, reader, writer): addr = writer.get_extra_info('peername') self.client = Client(addr) self.context = Context(self.client) - self.event_queue = event_queue # self.layer = ReverseProxy(self.context, ("localhost", 443)) self.layer = ReverseProxy(self.context, ("localhost", 80)) - self.transports: MutableMapping[Connection, StreamIO] = { + self.transports = { self.client: StreamIO(reader, writer) } - self.lock = asyncio.Lock() - - @classmethod - async def handle(cls, reader, writer): - await cls(reader, writer).handle_client() + def _debug(self, *args): + print(*args) async def handle_client(self): - await self.server_event(events.Start()) + self.server_event(events.Start()) await self.handle_connection(self.client) - print("client connection done, closing transports!") + self._debug("client connection done, closing transports!") if self.transports: await asyncio.wait([ @@ -53,13 +53,13 @@ class ConnectionHandler: for x in self.transports ]) - print("transports closed!") + self._debug("transports closed!") async def close_connection(self, connection): io = self.transports.pop(connection, None) if not io: - print(f"Already closed: {connection}") - print(f"Closing {connection}") + self._debug(f"Already closed: {connection}") + self._debug(f"Closing {connection}") try: await io.w.drain() io.w.write_eof() @@ -75,60 +75,70 @@ class ConnectionHandler: except socket.error: data = b"" if data: - await self.server_event(events.DataReceived(connection, data)) + self.server_event(events.DataReceived(connection, data)) else: connection.connected = False if connection in self.transports: await self.close_connection(connection) - await self.server_event(events.ConnectionClosed(connection)) + self.server_event(events.ConnectionClosed(connection)) break async def open_connection(self, command: commands.OpenConnection): - reader, writer = await asyncio.open_connection( - *command.connection.address - ) - self.transports[command.connection] = StreamIO(reader, writer) - command.connection.connected = True - await self.server_event(events.OpenConnectionReply(command, None)) - await self.handle_connection(command.connection) + try: + reader, writer = await asyncio.open_connection( + *command.connection.address + ) + except IOError as e: + self.server_event(events.OpenConnectionReply(command, str(e))) + else: + self.transports[command.connection] = StreamIO(reader, writer) + command.connection.connected = True + self.server_event(events.OpenConnectionReply(command, None)) + await self.handle_connection(command.connection) + + @abc.abstractmethod + async def handle_hook(self, hook: commands.Hook) -> None: + pass + + def server_event(self, event: events.Event) -> None: + self._debug(">>", event) + layer_commands = self.layer.handle_event(event) + for command in layer_commands: + self._debug("<<", command) + if isinstance(command, commands.OpenConnection): + asyncio.ensure_future( + self.open_connection(command) + ) + elif isinstance(command, commands.SendData): + self.transports[command.connection].w.write(command.data) + elif isinstance(command, commands.CloseConnection): + asyncio.ensure_future( + self.close_connection(command.connection) + ) + elif isinstance(command, commands.Hook): + asyncio.ensure_future( + self.handle_hook(command) + ) + else: + raise RuntimeError(f"Unexpected event: {command}") + + +class SimpleConnectionHandler(ConnectionHandler): + """Simple handler that does not process any hooks.""" async def handle_hook(self, hook: commands.Hook) -> None: - # TODO: temporary glue code - let's see how many years it survives. - hook.data.reply = controller.Reply(hook.data) - q = asyncio.Queue() - hook.data.reply.q = q - self.event_queue.put((hook.name, hook.data)) - reply = await q.get() - await self.server_event(events.HookReply(hook, reply)) + self.server_event(events.HookReply(hook, None)) - async def server_event(self, event: events.Event) -> None: - print("*", type(event).__name__) - async with self.lock: - print("<#", event) - layer_commands = self.layer.handle_event(event) - for command in layer_commands: - print("<<", command) - if isinstance(command, commands.OpenConnection): - asyncio.ensure_future(self.open_connection(command)) - elif isinstance(command, commands.SendData): - self.transports[command.connection].w.write(command.data) - elif isinstance(command, commands.Hook): - print(f"~ {command.name}: {command.data}") - asyncio.ensure_future( - self.handle_hook(command) - ) - elif isinstance(command, commands.CloseConnection): - asyncio.ensure_future( - self.close_connection(command.connection) - ) - else: - raise NotImplementedError(f"Unexpected event: {command}") - print("#>") if __name__ == "__main__": loop = asyncio.get_event_loop() - coro = asyncio.start_server(ConnectionHandler.handle, '127.0.0.1', 8080, loop=loop) + + async def handle(reader, writer): + await SimpleConnectionHandler(reader, writer).handle_client() + + + coro = asyncio.start_server(handle, '127.0.0.1', 8080, loop=loop) server = loop.run_until_complete(coro) # Serve requests until Ctrl+C is pressed