diff --git a/mitmproxy/proxy/protocol2/commands.py b/mitmproxy/proxy/protocol2/commands.py index 2baac64af..aa577cb0e 100644 --- a/mitmproxy/proxy/protocol2/commands.py +++ b/mitmproxy/proxy/protocol2/commands.py @@ -28,6 +28,14 @@ class Command: def __repr__(self): return f"{type(self).__name__}({repr(self.__dict__)})" + def __eq__(self, other): + if isinstance(other, self.__class__): + return all( + self.__dict__[k] == other.__dict__[k] + for k in self.__dict__ if k != "blocking" + ) + return False + class ConnectionCommand(Command): """ diff --git a/mitmproxy/proxy/protocol2/events.py b/mitmproxy/proxy/protocol2/events.py index 4aafe58d4..6f11a6c18 100644 --- a/mitmproxy/proxy/protocol2/events.py +++ b/mitmproxy/proxy/protocol2/events.py @@ -17,6 +17,11 @@ class Event: def __repr__(self): return f"{type(self).__name__}({repr(self.__dict__)})" + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + class Start(Event): """ diff --git a/mitmproxy/proxy/protocol2/layer.py b/mitmproxy/proxy/protocol2/layer.py index 55494ce95..dba6fdcd2 100644 --- a/mitmproxy/proxy/protocol2/layer.py +++ b/mitmproxy/proxy/protocol2/layer.py @@ -27,6 +27,9 @@ class Layer(metaclass=ABCMeta): self._paused = None self._paused_event_queue: typing.Deque[events.Event] = collections.deque() + def _debug(self, x): + pass # print(x) + @abstractmethod def handle(self, event: Event) -> commands.TCommandGenerator: """Handle a proxy server event""" @@ -38,13 +41,13 @@ class Layer(metaclass=ABCMeta): # did we just receive the reply we were waiting for? pause_finished = ( isinstance(event, events.CommandReply) and - event.command == self._paused.command + event.command is self._paused.command ) if pause_finished: yield from self.__continue(event) else: self._paused_event_queue.append(event) - print("Paused Event Queue: " + repr(self._paused_event_queue)) + self._debug("Paused Event Queue: " + repr(self._paused_event_queue)) else: command_generator = self.handle(event) yield from self.__process(command_generator) @@ -62,7 +65,7 @@ class Layer(metaclass=ABCMeta): while command: if command.blocking is True: - print("start pausing") + self._debug("start pausing") command.blocking = self # assign to our layer so that higher layers don't block. self._paused = Paused( command, @@ -76,14 +79,14 @@ class Layer(metaclass=ABCMeta): def __continue(self, event: events.CommandReply): """continue processing events after being paused""" - print("continue") + self._debug("continue") command_generator = self._paused.generator self._paused = None yield from self.__process(command_generator, event.reply) while not self._paused and self._paused_event_queue: event = self._paused_event_queue.popleft() - print(f"<# Paused event: {event}") + self._debug(f"<# Paused event: {event}") command_generator = self.handle(event) yield from self.__process(command_generator) - print("#>") + self._debug("#>") diff --git a/mitmproxy/proxy/protocol2/tcp.py b/mitmproxy/proxy/protocol2/tcp.py index b185f55c0..bb6cf7cf5 100644 --- a/mitmproxy/proxy/protocol2/tcp.py +++ b/mitmproxy/proxy/protocol2/tcp.py @@ -26,9 +26,7 @@ class TCPLayer(Layer): @expect(events.Start) def start(self, _) -> commands.TCommandGenerator: if not self.context.server.connected: - print(r"open connection...") ok = yield commands.OpenConnection(self.context.server) - print(r"connection opened! \o/", ok) self.state = self.relay_messages @expect(events.DataReceived, events.ConnectionClosed) diff --git a/mitmproxy/proxy/protocol2/test/__init__.py b/mitmproxy/proxy/protocol2/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mitmproxy/proxy/protocol2/test/conftest.py b/mitmproxy/proxy/protocol2/test/conftest.py new file mode 100644 index 000000000..110114070 --- /dev/null +++ b/mitmproxy/proxy/protocol2/test/conftest.py @@ -0,0 +1,11 @@ +import pytest + +from mitmproxy.proxy.protocol2 import context + + +@pytest.fixture +def tctx(): + return context.ClientServerContext( + context.Client("client"), + context.Server("server") + ) diff --git a/mitmproxy/proxy/protocol2/test/test_tcp.py b/mitmproxy/proxy/protocol2/test/test_tcp.py new file mode 100644 index 000000000..1111b35a6 --- /dev/null +++ b/mitmproxy/proxy/protocol2/test/test_tcp.py @@ -0,0 +1,46 @@ +from mitmproxy.proxy.protocol2 import tcp +from mitmproxy.proxy.protocol2.test.tutils import playbook +from .. import commands +from .. import events + + +def test_open_connection(tctx): + """ + If there is no server connection yet, establish one, + because the server may send data first. + """ + assert ( + playbook(tcp.TCPLayer(tctx)) + << commands.OpenConnection(tctx.server) + ) + + tctx.server.connected = True + assert ( + playbook(tcp.TCPLayer(tctx)) + << None + ) + + +def test_simple(tctx): + """open connection, receive data, send it to peer""" + assert ( + playbook(tcp.TCPLayer(tctx)) + << commands.OpenConnection(tctx.server) + >> events.OpenConnectionReply(-1, "ok") + >> events.ClientDataReceived(tctx.client, b"hello!") + << commands.SendData(tctx.server, b"hello!") + ) + + +def test_receive_data_before_server_connected(tctx): + """ + assert that data received before a server connection is established + will still be forwarded. + """ + assert ( + playbook(tcp.TCPLayer(tctx)) + << commands.OpenConnection(tctx.server) + >> events.ClientDataReceived(tctx.client, b"hello!") + >> events.OpenConnectionReply(-2, "ok") + << commands.SendData(tctx.server, b"hello!") + ) diff --git a/mitmproxy/proxy/protocol2/test/tutils.py b/mitmproxy/proxy/protocol2/test/tutils.py new file mode 100644 index 000000000..326cc7dfe --- /dev/null +++ b/mitmproxy/proxy/protocol2/test/tutils.py @@ -0,0 +1,84 @@ +import itertools +import typing + +from mitmproxy.proxy.protocol2 import commands +from mitmproxy.proxy.protocol2 import events +from mitmproxy.proxy.protocol2 import layer + + +class playbook: + """ + Assert that a layer emits the expected commands in reaction to a given sequence of events. + For example, the following code asserts that the TCP layer emits an OpenConnection command + immediately after starting and does not yield any further commands as a reaction to successful + connection establishment. + + assert playbook(tcp.TCPLayer(tctx)) \ + << commands.OpenConnection(tctx.server) + >> events.OpenConnectionReply(-1, "ok") # -1 = reply to command in previous line. + << None # this line is optional. + + This is syntactic sugar for the following: + + t = tcp.TCPLayer(tctx) + x1 = list(t.handle_event(events.Start())) + assert x1 == [commands.OpenConnection(tctx.server)] + x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1]))) + assert x2 == [] + """ + layer: layer.Layer + playbook: typing.List[typing.Union[commands.Command, events.Event]] + + def __init__( + self, + layer, + playbook=None, + ): + if playbook is None: + playbook = [ + events.Start() + ] + + self.layer = layer + self.playbook = playbook + + def __rshift__(self, e): + """Add an event to send""" + assert isinstance(e, events.Event) + self.playbook.append(e) + return self + + def __lshift__(self, c): + """Add an expected command""" + if c is None: + return self + assert isinstance(c, commands.Command) + self.playbook.append(c) + return self + + def __bool__(self): + """Determine if playbook is correct.""" + actual = [] + for i, x in enumerate(self.playbook): + if isinstance(x, commands.Command): + pass + else: + if isinstance(x, events.CommandReply): + if isinstance(x.command, int): + x.command = actual[i + x.command] + + actual.append(x) + actual.extend( + self.layer.handle_event(x) + ) + + if actual != self.playbook: + # print debug info + for a, e in itertools.zip_longest(actual, self.playbook): + if a == e: + print(f"= {e}") + else: + print(f"✓ {e}") + print(f"✗ {a}") + return False + return True