simplify async server

This commit is contained in:
Maximilian Hils 2017-06-28 17:27:52 +02:00
parent d8621553b1
commit a309cdb56c

View File

@ -6,46 +6,46 @@ The very high level overview is as follows:
- Process any commands from layer (such as opening a server connection)
- Wait for any IO and send it as events to top layer.
"""
import abc
import asyncio
import collections
import socket
from typing import MutableMapping
import typing
from mitmproxy import controller
from mitmproxy.proxy.protocol2 import events, commands
from mitmproxy.proxy.protocol2.context import Client, Context
from mitmproxy.proxy.protocol2.context import Connection
from mitmproxy.proxy.protocol2.reverse_proxy import ReverseProxy
StreamIO = collections.namedtuple('StreamIO', ['r', 'w'])
class StreamIO(typing.NamedTuple):
r: asyncio.StreamReader
w: asyncio.StreamWriter
class ConnectionHandler:
def __init__(self, event_queue, reader, writer):
class ConnectionHandler(metaclass=abc.ABCMeta):
transports: typing.MutableMapping[Connection, StreamIO]
def __init__(self, reader, writer):
addr = writer.get_extra_info('peername')
self.client = Client(addr)
self.context = Context(self.client)
self.event_queue = event_queue
# self.layer = ReverseProxy(self.context, ("localhost", 443))
self.layer = ReverseProxy(self.context, ("localhost", 80))
self.transports: MutableMapping[Connection, StreamIO] = {
self.transports = {
self.client: StreamIO(reader, writer)
}
self.lock = asyncio.Lock()
@classmethod
async def handle(cls, reader, writer):
await cls(reader, writer).handle_client()
def _debug(self, *args):
print(*args)
async def handle_client(self):
await self.server_event(events.Start())
self.server_event(events.Start())
await self.handle_connection(self.client)
print("client connection done, closing transports!")
self._debug("client connection done, closing transports!")
if self.transports:
await asyncio.wait([
@ -53,13 +53,13 @@ class ConnectionHandler:
for x in self.transports
])
print("transports closed!")
self._debug("transports closed!")
async def close_connection(self, connection):
io = self.transports.pop(connection, None)
if not io:
print(f"Already closed: {connection}")
print(f"Closing {connection}")
self._debug(f"Already closed: {connection}")
self._debug(f"Closing {connection}")
try:
await io.w.drain()
io.w.write_eof()
@ -75,60 +75,70 @@ class ConnectionHandler:
except socket.error:
data = b""
if data:
await self.server_event(events.DataReceived(connection, data))
self.server_event(events.DataReceived(connection, data))
else:
connection.connected = False
if connection in self.transports:
await self.close_connection(connection)
await self.server_event(events.ConnectionClosed(connection))
self.server_event(events.ConnectionClosed(connection))
break
async def open_connection(self, command: commands.OpenConnection):
reader, writer = await asyncio.open_connection(
*command.connection.address
)
self.transports[command.connection] = StreamIO(reader, writer)
command.connection.connected = True
await self.server_event(events.OpenConnectionReply(command, None))
await self.handle_connection(command.connection)
try:
reader, writer = await asyncio.open_connection(
*command.connection.address
)
except IOError as e:
self.server_event(events.OpenConnectionReply(command, str(e)))
else:
self.transports[command.connection] = StreamIO(reader, writer)
command.connection.connected = True
self.server_event(events.OpenConnectionReply(command, None))
await self.handle_connection(command.connection)
@abc.abstractmethod
async def handle_hook(self, hook: commands.Hook) -> None:
pass
def server_event(self, event: events.Event) -> None:
self._debug(">>", event)
layer_commands = self.layer.handle_event(event)
for command in layer_commands:
self._debug("<<", command)
if isinstance(command, commands.OpenConnection):
asyncio.ensure_future(
self.open_connection(command)
)
elif isinstance(command, commands.SendData):
self.transports[command.connection].w.write(command.data)
elif isinstance(command, commands.CloseConnection):
asyncio.ensure_future(
self.close_connection(command.connection)
)
elif isinstance(command, commands.Hook):
asyncio.ensure_future(
self.handle_hook(command)
)
else:
raise RuntimeError(f"Unexpected event: {command}")
class SimpleConnectionHandler(ConnectionHandler):
"""Simple handler that does not process any hooks."""
async def handle_hook(self, hook: commands.Hook) -> None:
# TODO: temporary glue code - let's see how many years it survives.
hook.data.reply = controller.Reply(hook.data)
q = asyncio.Queue()
hook.data.reply.q = q
self.event_queue.put((hook.name, hook.data))
reply = await q.get()
await self.server_event(events.HookReply(hook, reply))
self.server_event(events.HookReply(hook, None))
async def server_event(self, event: events.Event) -> None:
print("*", type(event).__name__)
async with self.lock:
print("<#", event)
layer_commands = self.layer.handle_event(event)
for command in layer_commands:
print("<<", command)
if isinstance(command, commands.OpenConnection):
asyncio.ensure_future(self.open_connection(command))
elif isinstance(command, commands.SendData):
self.transports[command.connection].w.write(command.data)
elif isinstance(command, commands.Hook):
print(f"~ {command.name}: {command.data}")
asyncio.ensure_future(
self.handle_hook(command)
)
elif isinstance(command, commands.CloseConnection):
asyncio.ensure_future(
self.close_connection(command.connection)
)
else:
raise NotImplementedError(f"Unexpected event: {command}")
print("#>")
if __name__ == "__main__":
loop = asyncio.get_event_loop()
coro = asyncio.start_server(ConnectionHandler.handle, '127.0.0.1', 8080, loop=loop)
async def handle(reader, writer):
await SimpleConnectionHandler(reader, writer).handle_client()
coro = asyncio.start_server(handle, '127.0.0.1', 8080, loop=loop)
server = loop.run_until_complete(coro)
# Serve requests until Ctrl+C is pressed