mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
replace .send() with async events
This commit is contained in:
parent
063c52065f
commit
4caa1dffbd
@ -18,7 +18,7 @@ class ConnectionHandler:
|
||||
self.client = Client(addr)
|
||||
self.context = Context(self.client)
|
||||
|
||||
self.layer = ReverseProxy(self.context, ("towel.blinkenlights.nl", 23))
|
||||
self.layer = ReverseProxy(self.context, ("example.com", 80))
|
||||
|
||||
self.transports = {} # type: MutableMapping[Connection, StreamIO]
|
||||
self.transports[self.client] = StreamIO(reader, writer)
|
||||
@ -26,8 +26,6 @@ class ConnectionHandler:
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def handle_client(self):
|
||||
await self.server_event(events.Start())
|
||||
|
||||
await self.handle_connection(self.client)
|
||||
|
||||
for connection in self.transports:
|
||||
@ -48,16 +46,28 @@ class ConnectionHandler:
|
||||
|
||||
async def handle_connection(self, connection):
|
||||
connection.connected = True
|
||||
await self.server_event(events.OpenConnection(connection))
|
||||
reader, writer = self.transports[connection]
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(4096)
|
||||
except socket.error:
|
||||
data = b""
|
||||
if data:
|
||||
await self.server_event(events.ReceiveData(connection, data))
|
||||
else:
|
||||
connection.connected = False
|
||||
await self.close(connection)
|
||||
await self.server_event(events.CloseConnection(connection))
|
||||
break
|
||||
|
||||
async def open_connection(self, event: events.OpenConnection):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*event.connection.address
|
||||
)
|
||||
self.transports[event.connection] = StreamIO(reader, writer)
|
||||
await self.handle_connection(event.connection)
|
||||
|
||||
async def server_event(self, event: events.Event):
|
||||
print("*", event)
|
||||
async with self.lock:
|
||||
@ -66,13 +76,7 @@ class ConnectionHandler:
|
||||
for event in layer_events:
|
||||
print("<<", event)
|
||||
if isinstance(event, events.OpenConnection):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*event.connection.address
|
||||
)
|
||||
self.transports[event.connection] = StreamIO(reader, writer)
|
||||
asyncio.ensure_future(self.handle_connection(event.connection))
|
||||
layer_events.send(42)
|
||||
|
||||
asyncio.ensure_future(self.open_connection(event))
|
||||
elif isinstance(event, events.SendData):
|
||||
self.transports[event.connection].w.write(event.data)
|
||||
else:
|
||||
|
@ -1,25 +1,96 @@
|
||||
import functools
|
||||
|
||||
from mitmproxy.proxy.protocol2 import events
|
||||
from mitmproxy.proxy.protocol2.context import ClientServerContext
|
||||
from mitmproxy.proxy.protocol2.events import TEventGenerator
|
||||
from mitmproxy.proxy.protocol2.layer import Layer
|
||||
|
||||
"""
|
||||
Utility decorators that help build state machines
|
||||
"""
|
||||
|
||||
|
||||
def defer(event_type):
|
||||
"""
|
||||
Queue up the events matching the specified event type and emit them immediately
|
||||
after the state has changed.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
deferred = []
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(self, event: events.Event):
|
||||
if isinstance(event, event_type):
|
||||
deferred.append(event)
|
||||
else:
|
||||
yield from f(self, event)
|
||||
if self.state != f:
|
||||
for event in deferred:
|
||||
yield from self.state(event)
|
||||
deferred.clear()
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def exit_on_close(f):
|
||||
"""
|
||||
Stop all further interaction once a single close event has been observed.
|
||||
"""
|
||||
closed = False
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(self, event: events.Event):
|
||||
nonlocal closed
|
||||
if isinstance(event, events.CloseConnection):
|
||||
closed = True
|
||||
if not closed:
|
||||
yield from f(self, event)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class TCPLayer(Layer):
|
||||
context = None # type: ClientServerContext
|
||||
|
||||
def handle_event(self, event: events.Event) -> TEventGenerator:
|
||||
if isinstance(event, events.Start):
|
||||
if not self.context.server.connected:
|
||||
try:
|
||||
t = yield events.OpenConnection(self.context.server)
|
||||
yield
|
||||
print("opening took {}s".format(t)) # example on how we can implement .ask()
|
||||
except Exception as e:
|
||||
print("Could not connect to server: {}".format(e))
|
||||
def __init__(self, context: ClientServerContext):
|
||||
super().__init__(context)
|
||||
self.state = self.start
|
||||
|
||||
def handle_event(self, event: events.Event) -> TEventGenerator:
|
||||
yield from self.state(event)
|
||||
|
||||
def start(self, event: events.Event) -> TEventGenerator:
|
||||
if isinstance(event, events.OpenConnection):
|
||||
if not self.context.server.connected:
|
||||
yield events.OpenConnection(self.context.server)
|
||||
self.state = self.wait_for_open
|
||||
else:
|
||||
self.state = self.relay_messages
|
||||
else:
|
||||
raise TypeError("Unexpected event: {}".format(event))
|
||||
|
||||
@defer(events.ReceiveData)
|
||||
@exit_on_close
|
||||
def wait_for_open(self, event: events.Event) -> TEventGenerator:
|
||||
if isinstance(event, events.OpenConnection):
|
||||
# connection is now open
|
||||
self.state = self.relay_messages
|
||||
else:
|
||||
raise TypeError("Unexpected event: {}".format(event))
|
||||
# noinspection PyUnreachableCode
|
||||
yield
|
||||
|
||||
def relay_messages(self, event: events.Event) -> TEventGenerator:
|
||||
if isinstance(event, events.ReceiveData):
|
||||
if event.connection == self.context.client:
|
||||
dst = self.context.server
|
||||
else:
|
||||
dst = self.context.client
|
||||
yield events.SendData(dst, event.data)
|
||||
if isinstance(event, events.CloseConnection):
|
||||
pass # TODO: close other connection here.
|
||||
else:
|
||||
raise TypeError("Unexpected event: {}".format(event))
|
||||
|
Loading…
Reference in New Issue
Block a user