replace .send() with async events

This commit is contained in:
Maximilian Hils 2016-11-19 00:15:27 +01:00
parent 063c52065f
commit 4caa1dffbd
2 changed files with 95 additions and 20 deletions

View File

@ -18,7 +18,7 @@ class ConnectionHandler:
self.client = Client(addr)
self.context = Context(self.client)
self.layer = ReverseProxy(self.context, ("towel.blinkenlights.nl", 23))
self.layer = ReverseProxy(self.context, ("example.com", 80))
self.transports = {} # type: MutableMapping[Connection, StreamIO]
self.transports[self.client] = StreamIO(reader, writer)
@ -26,8 +26,6 @@ class ConnectionHandler:
self.lock = asyncio.Lock()
async def handle_client(self):
await self.server_event(events.Start())
await self.handle_connection(self.client)
for connection in self.transports:
@ -48,16 +46,28 @@ class ConnectionHandler:
async def handle_connection(self, connection):
connection.connected = True
await self.server_event(events.OpenConnection(connection))
reader, writer = self.transports[connection]
while True:
try:
data = await reader.read(4096)
except socket.error:
data = b""
if data:
await self.server_event(events.ReceiveData(connection, data))
else:
connection.connected = False
await self.close(connection)
await self.server_event(events.CloseConnection(connection))
break
async def open_connection(self, event: events.OpenConnection):
reader, writer = await asyncio.open_connection(
*event.connection.address
)
self.transports[event.connection] = StreamIO(reader, writer)
await self.handle_connection(event.connection)
async def server_event(self, event: events.Event):
print("*", event)
async with self.lock:
@ -66,13 +76,7 @@ class ConnectionHandler:
for event in layer_events:
print("<<", event)
if isinstance(event, events.OpenConnection):
reader, writer = await asyncio.open_connection(
*event.connection.address
)
self.transports[event.connection] = StreamIO(reader, writer)
asyncio.ensure_future(self.handle_connection(event.connection))
layer_events.send(42)
asyncio.ensure_future(self.open_connection(event))
elif isinstance(event, events.SendData):
self.transports[event.connection].w.write(event.data)
else:

View File

@ -1,25 +1,96 @@
import functools
from mitmproxy.proxy.protocol2 import events
from mitmproxy.proxy.protocol2.context import ClientServerContext
from mitmproxy.proxy.protocol2.events import TEventGenerator
from mitmproxy.proxy.protocol2.layer import Layer
"""
Utility decorators that help build state machines
"""
def defer(event_type):
"""
Queue up the events matching the specified event type and emit them immediately
after the state has changed.
"""
def decorator(f):
deferred = []
@functools.wraps(f)
def wrapper(self, event: events.Event):
if isinstance(event, event_type):
deferred.append(event)
else:
yield from f(self, event)
if self.state != f:
for event in deferred:
yield from self.state(event)
deferred.clear()
return wrapper
return decorator
def exit_on_close(f):
"""
Stop all further interaction once a single close event has been observed.
"""
closed = False
@functools.wraps(f)
def wrapper(self, event: events.Event):
nonlocal closed
if isinstance(event, events.CloseConnection):
closed = True
if not closed:
yield from f(self, event)
return wrapper
class TCPLayer(Layer):
context = None # type: ClientServerContext
def handle_event(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.Start):
if not self.context.server.connected:
try:
t = yield events.OpenConnection(self.context.server)
yield
print("opening took {}s".format(t)) # example on how we can implement .ask()
except Exception as e:
print("Could not connect to server: {}".format(e))
def __init__(self, context: ClientServerContext):
super().__init__(context)
self.state = self.start
def handle_event(self, event: events.Event) -> TEventGenerator:
yield from self.state(event)
def start(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.OpenConnection):
if not self.context.server.connected:
yield events.OpenConnection(self.context.server)
self.state = self.wait_for_open
else:
self.state = self.relay_messages
else:
raise TypeError("Unexpected event: {}".format(event))
@defer(events.ReceiveData)
@exit_on_close
def wait_for_open(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.OpenConnection):
# connection is now open
self.state = self.relay_messages
else:
raise TypeError("Unexpected event: {}".format(event))
# noinspection PyUnreachableCode
yield
def relay_messages(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.ReceiveData):
if event.connection == self.context.client:
dst = self.context.server
else:
dst = self.context.client
yield events.SendData(dst, event.data)
if isinstance(event, events.CloseConnection):
pass # TODO: close other connection here.
else:
raise TypeError("Unexpected event: {}".format(event))