diff --git a/mitmproxy/proxy/protocol2/server_async.py b/mitmproxy/proxy/protocol2/server_async.py index 0f34a9697..64bdf1eb0 100644 --- a/mitmproxy/proxy/protocol2/server_async.py +++ b/mitmproxy/proxy/protocol2/server_async.py @@ -18,7 +18,7 @@ class ConnectionHandler: self.client = Client(addr) self.context = Context(self.client) - self.layer = ReverseProxy(self.context, ("towel.blinkenlights.nl", 23)) + self.layer = ReverseProxy(self.context, ("example.com", 80)) self.transports = {} # type: MutableMapping[Connection, StreamIO] self.transports[self.client] = StreamIO(reader, writer) @@ -26,8 +26,6 @@ class ConnectionHandler: self.lock = asyncio.Lock() async def handle_client(self): - await self.server_event(events.Start()) - await self.handle_connection(self.client) for connection in self.transports: @@ -48,16 +46,28 @@ class ConnectionHandler: async def handle_connection(self, connection): connection.connected = True + await self.server_event(events.OpenConnection(connection)) reader, writer = self.transports[connection] while True: - data = await reader.read(4096) + try: + data = await reader.read(4096) + except socket.error: + data = b"" if data: await self.server_event(events.ReceiveData(connection, data)) else: connection.connected = False await self.close(connection) + await self.server_event(events.CloseConnection(connection)) break + async def open_connection(self, event: events.OpenConnection): + reader, writer = await asyncio.open_connection( + *event.connection.address + ) + self.transports[event.connection] = StreamIO(reader, writer) + await self.handle_connection(event.connection) + async def server_event(self, event: events.Event): print("*", event) async with self.lock: @@ -66,13 +76,7 @@ class ConnectionHandler: for event in layer_events: print("<<", event) if isinstance(event, events.OpenConnection): - reader, writer = await asyncio.open_connection( - *event.connection.address - ) - self.transports[event.connection] = StreamIO(reader, writer) - asyncio.ensure_future(self.handle_connection(event.connection)) - layer_events.send(42) - + asyncio.ensure_future(self.open_connection(event)) elif isinstance(event, events.SendData): self.transports[event.connection].w.write(event.data) else: diff --git a/mitmproxy/proxy/protocol2/tcp.py b/mitmproxy/proxy/protocol2/tcp.py index 0e5df1eef..f937df5eb 100644 --- a/mitmproxy/proxy/protocol2/tcp.py +++ b/mitmproxy/proxy/protocol2/tcp.py @@ -1,25 +1,96 @@ +import functools + from mitmproxy.proxy.protocol2 import events from mitmproxy.proxy.protocol2.context import ClientServerContext from mitmproxy.proxy.protocol2.events import TEventGenerator from mitmproxy.proxy.protocol2.layer import Layer +""" +Utility decorators that help build state machines +""" + + +def defer(event_type): + """ + Queue up the events matching the specified event type and emit them immediately + after the state has changed. + """ + + def decorator(f): + deferred = [] + + @functools.wraps(f) + def wrapper(self, event: events.Event): + if isinstance(event, event_type): + deferred.append(event) + else: + yield from f(self, event) + if self.state != f: + for event in deferred: + yield from self.state(event) + deferred.clear() + + return wrapper + + return decorator + + +def exit_on_close(f): + """ + Stop all further interaction once a single close event has been observed. + """ + closed = False + + @functools.wraps(f) + def wrapper(self, event: events.Event): + nonlocal closed + if isinstance(event, events.CloseConnection): + closed = True + if not closed: + yield from f(self, event) + + return wrapper + class TCPLayer(Layer): context = None # type: ClientServerContext - def handle_event(self, event: events.Event) -> TEventGenerator: - if isinstance(event, events.Start): - if not self.context.server.connected: - try: - t = yield events.OpenConnection(self.context.server) - yield - print("opening took {}s".format(t)) # example on how we can implement .ask() - except Exception as e: - print("Could not connect to server: {}".format(e)) + def __init__(self, context: ClientServerContext): + super().__init__(context) + self.state = self.start + def handle_event(self, event: events.Event) -> TEventGenerator: + yield from self.state(event) + + def start(self, event: events.Event) -> TEventGenerator: + if isinstance(event, events.OpenConnection): + if not self.context.server.connected: + yield events.OpenConnection(self.context.server) + self.state = self.wait_for_open + else: + self.state = self.relay_messages + else: + raise TypeError("Unexpected event: {}".format(event)) + + @defer(events.ReceiveData) + @exit_on_close + def wait_for_open(self, event: events.Event) -> TEventGenerator: + if isinstance(event, events.OpenConnection): + # connection is now open + self.state = self.relay_messages + else: + raise TypeError("Unexpected event: {}".format(event)) + # noinspection PyUnreachableCode + yield + + def relay_messages(self, event: events.Event) -> TEventGenerator: if isinstance(event, events.ReceiveData): if event.connection == self.context.client: dst = self.context.server else: dst = self.context.client yield events.SendData(dst, event.data) + if isinstance(event, events.CloseConnection): + pass # TODO: close other connection here. + else: + raise TypeError("Unexpected event: {}".format(event))