diff --git a/mitmproxy/proxy/protocol2/server/server_async.py b/mitmproxy/proxy/protocol2/server/server_async.py index bd5d45631..a5eb845b8 100644 --- a/mitmproxy/proxy/protocol2/server/server_async.py +++ b/mitmproxy/proxy/protocol2/server/server_async.py @@ -26,8 +26,8 @@ class ConnectionHandler: self.client = Client(addr) self.context = Context(self.client) - # self.layer = ReverseProxy(self.context, ("example.com", 443)) - self.layer = ReverseProxy(self.context, ("example.com", 80)) + # self.layer = ReverseProxy(self.context, ("localhost", 443)) + self.layer = ReverseProxy(self.context, ("localhost", 80)) self.transports: MutableMapping[Connection, StreamIO] = { self.client: StreamIO(reader, writer) @@ -41,15 +41,19 @@ class ConnectionHandler: print("client connection done, closing transports!") - for connection in list(self.transports): - await self.close(connection) + if self.transports: + await asyncio.wait([ + self.close_connection(x) + for x in self.transports + ]) - # TODO: teardown all other conns. print("transports closed!") - async def close(self, connection): - print("Closing", connection) - io = self.transports.pop(connection) + async def close_connection(self, connection): + io = self.transports.pop(connection, None) + if not io: + print(f"Already closed: {connection}") + print(f"Closing {connection}") try: await io.w.drain() io.w.write_eof() @@ -69,7 +73,7 @@ class ConnectionHandler: else: connection.connected = False if connection in self.transports: - await self.close(connection) + await self.close_connection(connection) await self.server_event(events.ConnectionClosed(connection)) break @@ -97,10 +101,14 @@ class ConnectionHandler: # TODO: pass to master here. print(f"~ {command.name}: {command.data}") asyncio.ensure_future( - self.server_event(events.HookReply(command, "hook reply")) + self.server_event(events.HookReply(command, None)) + ) + elif isinstance(command, commands.CloseConnection): + asyncio.ensure_future( + self.close_connection(command.connection) ) else: - raise NotImplementedError("Unexpected event: {}".format(command)) + raise NotImplementedError(f"Unexpected event: {command}") print("#>")