replace .send() with async events

This commit is contained in:
Maximilian Hils 2016-11-19 00:15:27 +01:00
parent 063c52065f
commit 4caa1dffbd
2 changed files with 95 additions and 20 deletions

View File

@ -18,7 +18,7 @@ class ConnectionHandler:
self.client = Client(addr) self.client = Client(addr)
self.context = Context(self.client) self.context = Context(self.client)
self.layer = ReverseProxy(self.context, ("towel.blinkenlights.nl", 23)) self.layer = ReverseProxy(self.context, ("example.com", 80))
self.transports = {} # type: MutableMapping[Connection, StreamIO] self.transports = {} # type: MutableMapping[Connection, StreamIO]
self.transports[self.client] = StreamIO(reader, writer) self.transports[self.client] = StreamIO(reader, writer)
@ -26,8 +26,6 @@ class ConnectionHandler:
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
async def handle_client(self): async def handle_client(self):
await self.server_event(events.Start())
await self.handle_connection(self.client) await self.handle_connection(self.client)
for connection in self.transports: for connection in self.transports:
@ -48,16 +46,28 @@ class ConnectionHandler:
async def handle_connection(self, connection): async def handle_connection(self, connection):
connection.connected = True connection.connected = True
await self.server_event(events.OpenConnection(connection))
reader, writer = self.transports[connection] reader, writer = self.transports[connection]
while True: while True:
try:
data = await reader.read(4096) data = await reader.read(4096)
except socket.error:
data = b""
if data: if data:
await self.server_event(events.ReceiveData(connection, data)) await self.server_event(events.ReceiveData(connection, data))
else: else:
connection.connected = False connection.connected = False
await self.close(connection) await self.close(connection)
await self.server_event(events.CloseConnection(connection))
break break
async def open_connection(self, event: events.OpenConnection):
reader, writer = await asyncio.open_connection(
*event.connection.address
)
self.transports[event.connection] = StreamIO(reader, writer)
await self.handle_connection(event.connection)
async def server_event(self, event: events.Event): async def server_event(self, event: events.Event):
print("*", event) print("*", event)
async with self.lock: async with self.lock:
@ -66,13 +76,7 @@ class ConnectionHandler:
for event in layer_events: for event in layer_events:
print("<<", event) print("<<", event)
if isinstance(event, events.OpenConnection): if isinstance(event, events.OpenConnection):
reader, writer = await asyncio.open_connection( asyncio.ensure_future(self.open_connection(event))
*event.connection.address
)
self.transports[event.connection] = StreamIO(reader, writer)
asyncio.ensure_future(self.handle_connection(event.connection))
layer_events.send(42)
elif isinstance(event, events.SendData): elif isinstance(event, events.SendData):
self.transports[event.connection].w.write(event.data) self.transports[event.connection].w.write(event.data)
else: else:

View File

@ -1,25 +1,96 @@
import functools
from mitmproxy.proxy.protocol2 import events from mitmproxy.proxy.protocol2 import events
from mitmproxy.proxy.protocol2.context import ClientServerContext from mitmproxy.proxy.protocol2.context import ClientServerContext
from mitmproxy.proxy.protocol2.events import TEventGenerator from mitmproxy.proxy.protocol2.events import TEventGenerator
from mitmproxy.proxy.protocol2.layer import Layer from mitmproxy.proxy.protocol2.layer import Layer
"""
Utility decorators that help build state machines
"""
def defer(event_type):
"""
Queue up the events matching the specified event type and emit them immediately
after the state has changed.
"""
def decorator(f):
deferred = []
@functools.wraps(f)
def wrapper(self, event: events.Event):
if isinstance(event, event_type):
deferred.append(event)
else:
yield from f(self, event)
if self.state != f:
for event in deferred:
yield from self.state(event)
deferred.clear()
return wrapper
return decorator
def exit_on_close(f):
"""
Stop all further interaction once a single close event has been observed.
"""
closed = False
@functools.wraps(f)
def wrapper(self, event: events.Event):
nonlocal closed
if isinstance(event, events.CloseConnection):
closed = True
if not closed:
yield from f(self, event)
return wrapper
class TCPLayer(Layer): class TCPLayer(Layer):
context = None # type: ClientServerContext context = None # type: ClientServerContext
def handle_event(self, event: events.Event) -> TEventGenerator: def __init__(self, context: ClientServerContext):
if isinstance(event, events.Start): super().__init__(context)
if not self.context.server.connected: self.state = self.start
try:
t = yield events.OpenConnection(self.context.server)
yield
print("opening took {}s".format(t)) # example on how we can implement .ask()
except Exception as e:
print("Could not connect to server: {}".format(e))
def handle_event(self, event: events.Event) -> TEventGenerator:
yield from self.state(event)
def start(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.OpenConnection):
if not self.context.server.connected:
yield events.OpenConnection(self.context.server)
self.state = self.wait_for_open
else:
self.state = self.relay_messages
else:
raise TypeError("Unexpected event: {}".format(event))
@defer(events.ReceiveData)
@exit_on_close
def wait_for_open(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.OpenConnection):
# connection is now open
self.state = self.relay_messages
else:
raise TypeError("Unexpected event: {}".format(event))
# noinspection PyUnreachableCode
yield
def relay_messages(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.ReceiveData): if isinstance(event, events.ReceiveData):
if event.connection == self.context.client: if event.connection == self.context.client:
dst = self.context.server dst = self.context.server
else: else:
dst = self.context.client dst = self.context.client
yield events.SendData(dst, event.data) yield events.SendData(dst, event.data)
if isinstance(event, events.CloseConnection):
pass # TODO: close other connection here.
else:
raise TypeError("Unexpected event: {}".format(event))