mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
simplify async server
This commit is contained in:
parent
d8621553b1
commit
a309cdb56c
@ -6,46 +6,46 @@ The very high level overview is as follows:
|
||||
- Process any commands from layer (such as opening a server connection)
|
||||
- Wait for any IO and send it as events to top layer.
|
||||
"""
|
||||
import abc
|
||||
import asyncio
|
||||
import collections
|
||||
import socket
|
||||
from typing import MutableMapping
|
||||
import typing
|
||||
|
||||
from mitmproxy import controller
|
||||
from mitmproxy.proxy.protocol2 import events, commands
|
||||
from mitmproxy.proxy.protocol2.context import Client, Context
|
||||
from mitmproxy.proxy.protocol2.context import Connection
|
||||
from mitmproxy.proxy.protocol2.reverse_proxy import ReverseProxy
|
||||
|
||||
StreamIO = collections.namedtuple('StreamIO', ['r', 'w'])
|
||||
|
||||
class StreamIO(typing.NamedTuple):
|
||||
r: asyncio.StreamReader
|
||||
w: asyncio.StreamWriter
|
||||
|
||||
|
||||
class ConnectionHandler:
|
||||
def __init__(self, event_queue, reader, writer):
|
||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
transports: typing.MutableMapping[Connection, StreamIO]
|
||||
|
||||
def __init__(self, reader, writer):
|
||||
addr = writer.get_extra_info('peername')
|
||||
|
||||
self.client = Client(addr)
|
||||
self.context = Context(self.client)
|
||||
self.event_queue = event_queue
|
||||
|
||||
# self.layer = ReverseProxy(self.context, ("localhost", 443))
|
||||
self.layer = ReverseProxy(self.context, ("localhost", 80))
|
||||
|
||||
self.transports: MutableMapping[Connection, StreamIO] = {
|
||||
self.transports = {
|
||||
self.client: StreamIO(reader, writer)
|
||||
}
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def handle(cls, reader, writer):
|
||||
await cls(reader, writer).handle_client()
|
||||
def _debug(self, *args):
|
||||
print(*args)
|
||||
|
||||
async def handle_client(self):
|
||||
await self.server_event(events.Start())
|
||||
self.server_event(events.Start())
|
||||
await self.handle_connection(self.client)
|
||||
|
||||
print("client connection done, closing transports!")
|
||||
self._debug("client connection done, closing transports!")
|
||||
|
||||
if self.transports:
|
||||
await asyncio.wait([
|
||||
@ -53,13 +53,13 @@ class ConnectionHandler:
|
||||
for x in self.transports
|
||||
])
|
||||
|
||||
print("transports closed!")
|
||||
self._debug("transports closed!")
|
||||
|
||||
async def close_connection(self, connection):
|
||||
io = self.transports.pop(connection, None)
|
||||
if not io:
|
||||
print(f"Already closed: {connection}")
|
||||
print(f"Closing {connection}")
|
||||
self._debug(f"Already closed: {connection}")
|
||||
self._debug(f"Closing {connection}")
|
||||
try:
|
||||
await io.w.drain()
|
||||
io.w.write_eof()
|
||||
@ -75,60 +75,70 @@ class ConnectionHandler:
|
||||
except socket.error:
|
||||
data = b""
|
||||
if data:
|
||||
await self.server_event(events.DataReceived(connection, data))
|
||||
self.server_event(events.DataReceived(connection, data))
|
||||
else:
|
||||
connection.connected = False
|
||||
if connection in self.transports:
|
||||
await self.close_connection(connection)
|
||||
await self.server_event(events.ConnectionClosed(connection))
|
||||
self.server_event(events.ConnectionClosed(connection))
|
||||
break
|
||||
|
||||
async def open_connection(self, command: commands.OpenConnection):
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*command.connection.address
|
||||
)
|
||||
except IOError as e:
|
||||
self.server_event(events.OpenConnectionReply(command, str(e)))
|
||||
else:
|
||||
self.transports[command.connection] = StreamIO(reader, writer)
|
||||
command.connection.connected = True
|
||||
await self.server_event(events.OpenConnectionReply(command, None))
|
||||
self.server_event(events.OpenConnectionReply(command, None))
|
||||
await self.handle_connection(command.connection)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
# TODO: temporary glue code - let's see how many years it survives.
|
||||
hook.data.reply = controller.Reply(hook.data)
|
||||
q = asyncio.Queue()
|
||||
hook.data.reply.q = q
|
||||
self.event_queue.put((hook.name, hook.data))
|
||||
reply = await q.get()
|
||||
await self.server_event(events.HookReply(hook, reply))
|
||||
pass
|
||||
|
||||
async def server_event(self, event: events.Event) -> None:
|
||||
print("*", type(event).__name__)
|
||||
async with self.lock:
|
||||
print("<#", event)
|
||||
def server_event(self, event: events.Event) -> None:
|
||||
self._debug(">>", event)
|
||||
layer_commands = self.layer.handle_event(event)
|
||||
for command in layer_commands:
|
||||
print("<<", command)
|
||||
self._debug("<<", command)
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
asyncio.ensure_future(self.open_connection(command))
|
||||
asyncio.ensure_future(
|
||||
self.open_connection(command)
|
||||
)
|
||||
elif isinstance(command, commands.SendData):
|
||||
self.transports[command.connection].w.write(command.data)
|
||||
elif isinstance(command, commands.Hook):
|
||||
print(f"~ {command.name}: {command.data}")
|
||||
asyncio.ensure_future(
|
||||
self.handle_hook(command)
|
||||
)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
asyncio.ensure_future(
|
||||
self.close_connection(command.connection)
|
||||
)
|
||||
elif isinstance(command, commands.Hook):
|
||||
asyncio.ensure_future(
|
||||
self.handle_hook(command)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected event: {command}")
|
||||
print("#>")
|
||||
raise RuntimeError(f"Unexpected event: {command}")
|
||||
|
||||
|
||||
class SimpleConnectionHandler(ConnectionHandler):
|
||||
"""Simple handler that does not process any hooks."""
|
||||
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
self.server_event(events.HookReply(hook, None))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
coro = asyncio.start_server(ConnectionHandler.handle, '127.0.0.1', 8080, loop=loop)
|
||||
|
||||
async def handle(reader, writer):
|
||||
await SimpleConnectionHandler(reader, writer).handle_client()
|
||||
|
||||
|
||||
coro = asyncio.start_server(handle, '127.0.0.1', 8080, loop=loop)
|
||||
server = loop.run_until_complete(coro)
|
||||
|
||||
# Serve requests until Ctrl+C is pressed
|
||||
|
Loading…
Reference in New Issue
Block a user