implement blocking yields

This commit is contained in:
Maximilian Hils 2017-06-14 21:05:51 +02:00
parent d64b3e9491
commit c1ddd478cc
11 changed files with 310 additions and 131 deletions

View File

@ -0,0 +1,15 @@
"""
Experimental sans-io implementation of mitmproxy's protocol stack.
Most important primitives:
- layers: represent protocol layers, e.g. one for tcp, tls, and so on. Layers are stacked, so
a typical configuration might be ReverseProxy/TLS/TCP.
- server: the proxy server does all IO and communication with the mitmproxy master.
It creates the top layer for each incoming client connection.
- events: When IO actions occur at the proxy server, they are passed down to the top layer as events.
- commands: In the other direction, layers can emit commands to higher layers or the proxy server.
This is used to e.g. send data, request for new connections to be opened, or to use mitmproxy's
script hooks.
- context: The context is the connection context each layer is provided with. This is still very
much WIP, but this should expose stuff like Server Name Indication to lower layers.
"""

View File

@ -0,0 +1,73 @@
"""
Commands make it possible for layers to communicate with the "outer world",
e.g. to perform IO or to ask the master.
A command is issued by a proxy layer and is then passed upwards to the proxy server, and from there
possibly to the master and addons.
The counterpart to commands are events.
"""
import typing
from mitmproxy.proxy.protocol2.context import Connection
class Command:
"""
Base class for all commands
"""
blocking: bool = False
"""
Determines if the command blocks until it has been completed.
Example:
reply = yield Hook("requestheaders", flow)
"""
def __repr__(self):
return f"{type(self).__name__}({repr(self.__dict__)})"
class ConnectionCommand(Command):
"""
Commands involving a specific connection
"""
connection: Connection
def __init__(self, connection: Connection) -> None:
self.connection = connection
class SendData(ConnectionCommand):
"""
Send data to a remote peer
"""
data: bytes
def __init__(self, connection: Connection, data: bytes) -> None:
super().__init__(connection)
self.data = data
class OpenConnection(ConnectionCommand):
"""
Open a new connection
"""
blocking = True
class Hook(Command):
"""
Callback to the master (like ".ask()")
"""
blocking = True
name: str
data: typing.Any
def __init__(self, name: str, data: typing.Any) -> None:
self.name = name
self.data = data
TCommandGenerator = typing.Generator[Command, typing.Any, None]

View File

@ -2,11 +2,11 @@ class Connection:
"""
Connections exposed to the layers only contain metadata, no socket objects.
"""
address = None # type: tuple
connected = None # type: bool
address: tuple
connected: bool
def __repr__(self):
return "{}({})".format(type(self).__name__, repr(self.__dict__))
return f"{type(self).__name__}({repr(self.__dict__)})"
class Client(Connection):
@ -32,7 +32,7 @@ class Context:
lookup did in the previous implementation.
"""
client = None # type: Client
client: Client
def __init__(self, client: Client) -> None:
self.client = client
@ -43,7 +43,7 @@ class ClientServerContext(Context):
In most cases, there's also only exactly one server.
"""
server = None # type: Server
server: Server
def __init__(self, client: Client, server: Server) -> None:
super().__init__(client)

View File

@ -1,64 +1,90 @@
"""
The only way for layers to do IO is to emit events indicating what should be done.
For example, a layer may first emit a OpenConnection event and then a SendData event.
Likewise, layers only receive IO via events.
When IO actions occur at the proxy server, they are passed down to layers as events.
Events represent the only way for layers to receive new data from sockets.
The counterpart to events are commands.
"""
from typing import Iterable
import typing
from mitmproxy.proxy.protocol2 import commands
from mitmproxy.proxy.protocol2.context import Connection
class Event:
"""
Base class for all events.
"""
def __repr__(self):
return "{}({})".format(type(self).__name__, repr(self.__dict__))
TEventGenerator = Iterable[Event]
return f"{type(self).__name__}({repr(self.__dict__)})"
class Start(Event):
"""
Every layer initially receives a start event.
This is useful to emit events on startup, which otherwise would not be possible.
This is useful to emit events on startup.
"""
pass
class ConnectionEvent(Event):
"""
All events involving IO connections.
All events involving connection IO.
"""
connection: Connection
def __init__(self, connection: Connection):
self.connection = connection
class OpenConnection(ConnectionEvent):
pass
class CloseConnection(ConnectionEvent):
"""
(this would be send by proxy and by layers. keep it that way?)
Remote has closed a connection.
"""
pass
class SendData(ConnectionEvent):
def __init__(self, connection: Connection, data: bytes) -> None:
super().__init__(connection)
self.data = data
class ReceiveData(ConnectionEvent):
"""
Remote has sent some data.
"""
def __init__(self, connection: Connection, data: bytes) -> None:
super().__init__(connection)
self.data = data
class ReceiveClientData(ReceiveData):
"""
Client has sent data.
These subclasses simplify code for simple layers with one server and one client.
"""
pass
class ReceiveServerData(ReceiveData):
pass
class CommandReply(Event):
"""
Emitted when a command has been finished, e.g.
when the master has replied or when we have established a server connection.
"""
command: commands.Command
reply: typing.Any
def __init__(self, command: commands.Command, reply: typing.Any):
self.command = command
self.reply = reply
def __new__(cls, *args, **kwargs):
if cls is CommandReply:
raise TypeError("CommandReply may not be instantiated directly.")
return super().__new__(cls)
class OpenConnectionReply(CommandReply):
reply: bool
def __init__(self, command: commands.Command, ok: bool):
super().__init__(command, ok)

View File

@ -1,14 +1,89 @@
"""
Base class for protocol layers.
"""
import collections
import typing
from abc import ABCMeta, abstractmethod
from mitmproxy.proxy.protocol2 import commands, events
from mitmproxy.proxy.protocol2.context import Context
from mitmproxy.proxy.protocol2.events import Event, TEventGenerator
from mitmproxy.proxy.protocol2.events import Event
class Paused(typing.NamedTuple):
"""
State of a layer that's paused because it is waiting for a command reply.
"""
command: commands.Command
generator: commands.TCommandGenerator
class Layer(metaclass=ABCMeta):
context: Context
_paused: typing.Optional[Paused]
def __init__(self, context: Context) -> None:
self.context = context
self._paused = None
self._paused_event_queue: typing.Deque[events.Event] = collections.deque()
@abstractmethod
def handle_event(self, event: Event) -> TEventGenerator:
def handle(self, event: Event) -> commands.TCommandGenerator:
"""Handle a proxy server event"""
if False:
yield None
def handle_event(self, event: Event) -> commands.TCommandGenerator:
if self._paused:
# did we just receive the reply we were waiting for?
pause_finished = (
isinstance(event, events.CommandReply) and
event.command == self._paused.command
)
if pause_finished:
yield from self.__continue(event)
else:
self._paused_event_queue.append(event)
print("Paused Event Queue: " + repr(self._paused_event_queue))
else:
command_generator = self.handle(event)
yield from self.__process(command_generator)
def __process(self, command_generator: commands.TCommandGenerator, send=None):
"""
yield all commands from a generator.
if a command is blocking, the layer is paused and this function returns before
processing any other commands.
"""
try:
command = command_generator.send(send)
except StopIteration:
return
while command:
if command.blocking is True:
print("start pausing")
command.blocking = self # assign to our layer so that higher layers don't block.
self._paused = Paused(
command,
command_generator,
)
yield command
return
else:
yield command
command = next(command_generator, None)
def __continue(self, event: events.CommandReply):
"""continue processing events after being paused"""
print("continue")
command_generator = self._paused.generator
self._paused = None
yield from self.__process(command_generator, event.reply)
while not self._paused and self._paused_event_queue:
event = self._paused_event_queue.popleft()
print(f"<# Paused event: {event}")
command_generator = self.handle(event)
yield from self.__process(command_generator)
print("#>")

View File

@ -1,5 +1,6 @@
from mitmproxy.proxy.protocol2.commands import TCommandGenerator
from mitmproxy.proxy.protocol2.context import ClientServerContext, Context, Server
from mitmproxy.proxy.protocol2.events import Event, TEventGenerator
from mitmproxy.proxy.protocol2.events import Event
from mitmproxy.proxy.protocol2.layer import Layer
from mitmproxy.proxy.protocol2.tls import TLSLayer
@ -10,14 +11,7 @@ class ReverseProxy(Layer):
server = Server(server_addr)
self.child_context = ClientServerContext(context.client, server)
self.child_layer = TLSLayer(self.child_context, True, True)
# self.child_layer = TCPLayer(self.child_context)
def handle_event(self, event: Event) -> TEventGenerator:
def handle(self, event: Event) -> TCommandGenerator:
yield from self.child_layer.handle_event(event)
# If we cannot use yield from, we have to use something like this:
# x = None
# evts = self.child_layer.handle_event(event)
# while True:
# x = yield evts.send(x)
# https://www.python.org/dev/peps/pep-0380/#formal-semantics
# This is obviously ugly - but do we have any cases where we need to intercept messages like this?

View File

@ -1,3 +1,4 @@
# This is outdated, only the async version is kept up to date.
"""
Minimal server implementation based on https://docs.python.org/3/library/selectors.html#examples.
May be worth to replace this with something asyncio-based to overcome the issues outlined by the
@ -8,7 +9,7 @@ import selectors
import socket
from typing import MutableMapping
from mitmproxy.proxy.protocol2 import events
from mitmproxy.proxy.protocol2 import events, commands
from mitmproxy.proxy.protocol2.context import Connection
from mitmproxy.proxy.protocol2.context import Context, Client
from mitmproxy.proxy.protocol2.events import Event
@ -30,7 +31,7 @@ class ConnectionHandler:
layer = ReverseProxy(context, ("example.com", 80))
self.server_event(layer, events.OpenConnection(client))
# self.server_event(layer, commands.OpenConnection(client))
callback = functools.partial(self.read, layer, client)
@ -57,7 +58,7 @@ class ConnectionHandler:
layer_events = layer.handle_event(event)
for event in layer_events:
print("<<", event)
if isinstance(event, events.OpenConnection):
if isinstance(event, commands.OpenConnection):
# FIXME: This is blocking!
sock = socket.create_connection(event.connection.address)
sock.setblocking(False)
@ -70,7 +71,7 @@ class ConnectionHandler:
callback
)
self.connections[event.connection] = sock
elif isinstance(event, events.SendData):
elif isinstance(event, commands.SendData):
# FIXME: This may fail.
self.connections[event.connection].sendall(event.data)
else:
@ -106,11 +107,3 @@ class TCPServer:
if __name__ == '__main__':
s = TCPServer(('', 8080))
s.run()
"""
1) full async
uh?
2) notifier
3) thread per connection
"""

View File

@ -1,9 +1,17 @@
"""
Proxy Server Implementation using asyncio.
The very high level overview is as follows:
- Spawn one coroutine per client connection and create a reverse proxy layer to example.com
- Process any commands from layer (such as opening a server connection)
- Wait for any IO and send it as events to top layer.
"""
import asyncio
import collections
import socket
from typing import MutableMapping
from mitmproxy.proxy.protocol2 import events
from mitmproxy.proxy.protocol2 import events, commands
from mitmproxy.proxy.protocol2.context import Client, Context
from mitmproxy.proxy.protocol2.context import Connection
from mitmproxy.proxy.protocol2.reverse_proxy import ReverseProxy
@ -19,9 +27,11 @@ class ConnectionHandler:
self.context = Context(self.client)
self.layer = ReverseProxy(self.context, ("example.com", 443))
# self.layer = ReverseProxy(self.context, ("example.com", 80))
self.transports = {} # type: MutableMapping[Connection, StreamIO]
self.transports[self.client] = StreamIO(reader, writer)
self.transports: MutableMapping[Connection, StreamIO] = {
self.client: StreamIO(reader, writer)
}
self.lock = asyncio.Lock()
@ -29,12 +39,13 @@ class ConnectionHandler:
await self.server_event(events.Start())
await self.handle_connection(self.client)
for connection in self.transports:
# FIXME: dictionary is changing size during iteration
print("client connection done, closing transports!")
for connection in list(self.transports):
await self.close(connection)
# TODO: teardown all other conns.
print("client connection done!")
print("transports closed!")
async def close(self, connection):
print("Closing", connection)
@ -47,9 +58,6 @@ class ConnectionHandler:
io.w.close()
async def handle_connection(self, connection):
connection.connected = True
if connection != self.client:
await self.server_event(events.OpenConnection(connection))
reader, writer = self.transports[connection]
while True:
try:
@ -63,27 +71,30 @@ class ConnectionHandler:
await self.server_event(events.ReceiveServerData(connection, data))
else:
connection.connected = False
await self.close(connection)
if connection in self.transports:
await self.close(connection)
await self.server_event(events.CloseConnection(connection))
break
async def open_connection(self, event: events.OpenConnection):
async def open_connection(self, command: commands.OpenConnection):
reader, writer = await asyncio.open_connection(
*event.connection.address
*command.connection.address
)
self.transports[event.connection] = StreamIO(reader, writer)
await self.handle_connection(event.connection)
self.transports[command.connection] = StreamIO(reader, writer)
command.connection.connected = True
await self.server_event(events.OpenConnectionReply(command, "success"))
await self.handle_connection(command.connection)
async def server_event(self, event: events.Event):
print("*", event)
print("*", type(event).__name__)
async with self.lock:
print("<#", event)
layer_events = self.layer.handle_event(event)
for event in layer_events:
print("<<", event)
if isinstance(event, events.OpenConnection):
if isinstance(event, commands.OpenConnection):
asyncio.ensure_future(self.open_connection(event))
elif isinstance(event, events.SendData):
elif isinstance(event, commands.SendData):
self.transports[event.connection].w.write(event.data)
else:
raise NotImplementedError("Unexpected event: {}".format(event))

View File

@ -1,45 +1,44 @@
import functools
import typing
from warnings import warn
from mitmproxy.proxy.protocol2 import events
from mitmproxy.proxy.protocol2 import events, commands
from mitmproxy.proxy.protocol2.context import ClientServerContext
from mitmproxy.proxy.protocol2.events import TEventGenerator
from mitmproxy.proxy.protocol2.layer import Layer
from mitmproxy.proxy.protocol2.utils import defer, only
from mitmproxy.proxy.protocol2.utils import exit_on_close
from mitmproxy.proxy.protocol2.utils import expect
class TCPLayer(Layer):
context = None # type: ClientServerContext
"""
Simple TCP layer that just relays messages right now.
"""
context: ClientServerContext = None
# this is like a mini state machine.
state: typing.Callable[[events.Event], commands.TCommandGenerator]
def __init__(self, context: ClientServerContext):
super().__init__(context)
self.state = self.start
def handle_event(self, event: events.Event) -> TEventGenerator:
def handle(self, event: events.Event) -> commands.TCommandGenerator:
yield from self.state(event)
@only(events.Start)
def start(self, _) -> TEventGenerator:
@expect(events.Start)
def start(self, _) -> commands.TCommandGenerator:
if not self.context.server.connected:
yield events.OpenConnection(self.context.server)
self.state = self.wait_for_open
else:
self.state = self.relay_messages
@defer(events.ReceiveData)
@exit_on_close
@only(events.OpenConnection)
def wait_for_open(self, _) -> TEventGenerator:
print(r"open connection...")
ok = yield commands.OpenConnection(self.context.server)
print(r"connection opened! \o/", ok)
self.state = self.relay_messages
yield from []
@only(events.ReceiveData, events.CloseConnection)
def relay_messages(self, event: events.Event) -> TEventGenerator:
@expect(events.ReceiveData, events.CloseConnection)
def relay_messages(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.ReceiveClientData):
yield events.SendData(self.context.server, event.data)
yield commands.SendData(self.context.server, event.data)
elif isinstance(event, events.ReceiveServerData):
yield events.SendData(self.context.client, event.data)
yield commands.SendData(self.context.client, event.data)
elif isinstance(event, events.CloseConnection):
warn("unimplemented: tcp.relay_message:close")
# TODO: close other connection here.

View File

@ -1,45 +1,45 @@
"""
TLS man-in-the-middle layer.
"""
# We may want to split this up into client (only once) and server (for every server) layer.
import os
from typing import MutableMapping
from warnings import warn
from OpenSSL import SSL
from mitmproxy.certs import CertStore
from mitmproxy.options import DEFAULT_CLIENT_CIPHERS
from mitmproxy.proxy.protocol2 import events
from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS
from mitmproxy.proxy.protocol2 import events, commands
from mitmproxy.proxy.protocol2.context import ClientServerContext, Connection
from mitmproxy.proxy.protocol2.events import TEventGenerator
from mitmproxy.proxy.protocol2.layer import Layer
from mitmproxy.proxy.protocol2.tcp import TCPLayer
from mitmproxy.proxy.protocol2.utils import only, defer
from mitmproxy.proxy.protocol2.utils import expect
class TLSLayer(Layer):
context = None # type: ClientServerContext
client_tls = None # type: bool
server_tls = None # type: bool
# client_data = None # type: Buffer
# server_data = None # type: Buffer
child_layer = None # type: Layer
context: ClientServerContext = None
client_tls: bool = None # FIXME: not yet used.
server_tls: bool = None
child_layer: Layer = None
def __init__(self, context: ClientServerContext, client_tls: bool, server_tls: bool):
super().__init__(context)
self.state = self.start
self.client_tls = client_tls
self.server_tls = server_tls
# self.client_data = Buffer()
# self.server_data = Buffer()
self.tls = {} # type: MutableMapping[Connection, SSL.Connection]
self.tls: MutableMapping[Connection, SSL.Connection] = {}
def handle_event(self, event: events.Event) -> TEventGenerator:
def handle(self, event: events.Event) -> commands.TCommandGenerator:
yield from self.state(event)
@only(events.Start)
def start(self, _) -> TEventGenerator:
@expect(events.Start)
def start(self, _) -> commands.TCommandGenerator:
yield from self.start_client_tls()
if not self.context.server.connected:
yield events.OpenConnection(self.context.server)
else:
yield from self.start_server_tls()
# TODO: This should be lazy.
yield commands.OpenConnection(self.context.server)
yield from self.start_server_tls()
self.state = self.establish_tls
def start_client_tls(self):
@ -77,13 +77,11 @@ class TLSLayer(Layer):
# Okay, nothing more waiting to be sent.
return
else:
yield events.SendData(conn, data)
yield commands.SendData(conn, data)
@only(events.OpenConnection, events.CloseConnection, events.ReceiveData)
def establish_tls(self, event: events.Event) -> TEventGenerator:
if isinstance(event, events.OpenConnection):
yield from self.start_server_tls()
elif isinstance(event, events.ReceiveData):
@expect(events.CloseConnection, events.ReceiveData)
def establish_tls(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.ReceiveData):
self.tls[event.connection].bio_write(event.data)
try:
self.tls[event.connection].do_handshake()
@ -93,32 +91,23 @@ class TLSLayer(Layer):
both_handshakes_done = (
self.tls[self.context.client].get_peer_finished() and
self.context.server in self.tls and self.tls[self.context.server].get_peer_finished()
self.context.server in self.tls and self.tls[
self.context.server].get_peer_finished()
)
if both_handshakes_done:
print("both handshakes done")
# FIXME: This'd be accomplised by asking the master.
self.child_layer = TCPLayer(self.context)
yield from self.child_layer.handle_event(events.Start())
self.state = self.relay_messages
yield from self.state(events.ReceiveData(self.context.server, b""))
yield from self.state(events.ReceiveData(self.context.client, b""))
elif isinstance(event, events.CloseConnection):
warn("unimplemented: tls.establish_tls:close")
@defer(events.ReceiveData, events.CloseConnection)
def set_next_layer(self, layer):
# FIXME: That'd be a proper event, not just the layer.
self.child_layer = layer # type: Layer
self.state = self.relay_messages
@only(events.CloseConnection, events.ReceiveData)
def relay_messages(self, event: events.Event) -> TEventGenerator:
@expect(events.CloseConnection, events.ReceiveData)
def relay_messages(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.ReceiveData):
if event.data:
self.tls[event.connection].bio_write(event.data)
@ -135,7 +124,7 @@ class TLSLayer(Layer):
event_for_child = events.ReceiveServerData(self.context.server, plaintext)
for event_from_child in self.child_layer.handle_event(event_for_child):
if isinstance(event_from_child, events.SendData):
if isinstance(event_from_child, commands.SendData):
self.tls[event_from_child.connection].sendall(event_from_child.data)
yield from self.tls_interact(event_from_child.connection)
else:

View File

@ -7,6 +7,7 @@ from typing import Optional
from mitmproxy.proxy.protocol2 import events
# This is not used at the moment.
class Buffer:
def __init__(self):
self._buffer = bytearray()
@ -43,7 +44,7 @@ class Buffer:
return bytes(chunk)
def only(*event_types):
def expect(*event_types):
"""
Only allow the given event type.
If another event is passed, a TypeError is raised.
@ -55,13 +56,15 @@ def only(*event_types):
if isinstance(event, event_types):
yield from f(self, event)
else:
raise TypeError("Invalid event type: Expected {}, got {}".format(event_types, event))
raise TypeError(
"Invalid event type: Expected {}, got {}".format(event_types, event))
return wrapper
return decorator
# not used at the moment. We may not need this at all if the blocking yield continues to work as expected.
def defer(*event_types):
"""
Queue up the events matching the specified event type and emit them immediately
@ -87,6 +90,7 @@ def defer(*event_types):
return decorator
# not used at the moment.
def exit_on_close(f):
"""
Stop all further interaction once a single close event has been observed.