From e0eb77a794e79548b4a16720845e1b6375c27f92 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 9 Nov 2019 00:27:25 +0100 Subject: [PATCH] [sans-io] add transparent proxy, improve testing --- mitmproxy/proxy2/commands.py | 8 ++ mitmproxy/proxy2/context.py | 9 ++ mitmproxy/proxy2/events.py | 32 ++++- mitmproxy/proxy2/layers/__init__.py | 5 - mitmproxy/proxy2/layers/modes.py | 36 ++++-- mitmproxy/proxy2/layers/tcp.py | 39 +++--- mitmproxy/proxy2/layers/tls.py | 4 +- mitmproxy/proxy2/server.py | 5 +- mitmproxy/proxy2/utils.py | 3 +- test/mitmproxy/platform/__init__.py | 0 test/mitmproxy/proxy2/layers/test_tcp.py | 157 ++++++++++------------- test/mitmproxy/proxy2/test_tutils.py | 8 +- test/mitmproxy/proxy2/tutils.py | 101 ++++++++++----- 13 files changed, 245 insertions(+), 162 deletions(-) delete mode 100644 test/mitmproxy/platform/__init__.py diff --git a/mitmproxy/proxy2/commands.py b/mitmproxy/proxy2/commands.py index ef9eb1620..b84a6471c 100644 --- a/mitmproxy/proxy2/commands.py +++ b/mitmproxy/proxy2/commands.py @@ -90,6 +90,14 @@ class Hook(Command): # return f"Hook({self.name}: {data})" +class GetSocket(ConnectionCommand): + """ + Get the underlying socket. + This should really never be used, but is required to implement transparent mode. + """ + blocking = True + + class Log(Command): message: str level: str diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index 714cb1521..1574cc5b8 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -27,6 +27,15 @@ class Connection: def connected(self): return self.state is ConnectionState.OPEN + @connected.setter + def connected(self, val: bool) -> None: + # We should really set .state, but verdict is still due if we even want to keep .state around. + # We allow setting .connected while we figure that out. + if val: + self.state = ConnectionState.OPEN + else: + self.state = ConnectionState.CLOSED + def __repr__(self): return f"{type(self).__name__}({repr(self.__dict__)})" diff --git a/mitmproxy/proxy2/events.py b/mitmproxy/proxy2/events.py index c1c4c4819..1ec5cebe2 100644 --- a/mitmproxy/proxy2/events.py +++ b/mitmproxy/proxy2/events.py @@ -3,6 +3,7 @@ When IO actions occur at the proxy server, they are passed down to layers as eve Events represent the only way for layers to receive new data from sockets. The counterpart to events are commands. """ +import socket import typing from mitmproxy.proxy2 import commands @@ -62,7 +63,7 @@ class CommandReply(Event): Emitted when a command has been finished, e.g. when the master has replied or when we have established a server connection. """ - command: typing.Union[commands.Command, int] + command: commands.Command reply: typing.Any def __init__(self, command: typing.Union[commands.Command, int], reply: typing.Any): @@ -74,10 +75,19 @@ class CommandReply(Event): raise TypeError("CommandReply may not be instantiated directly.") return super().__new__(cls) + def __init_subclass__(cls, **kwargs): + command_cls = cls.__annotations__["command"] + if not issubclass(command_cls, commands.Command) and command_cls is not commands.Command: + raise RuntimeError(f"{command_cls} needs a properly annotated command attribute.") + command_reply_subclasses[command_cls] = cls + + +command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {} + class OpenConnectionReply(CommandReply): - command: typing.Union[commands.OpenConnection, int] - reply: str + command: commands.OpenConnection + reply: typing.Optional[str] def __init__( self, @@ -88,10 +98,22 @@ class OpenConnectionReply(CommandReply): class HookReply(CommandReply): - command: typing.Union[commands.Hook, int] + command: commands.Hook - def __init__(self, command: typing.Union[commands.Hook, int]): + def __init__(self, command: commands.Hook): super().__init__(command, None) def __repr__(self): return f"HookReply({repr(self.command)[5:-1]})" + + +class GetSocketReply(CommandReply): + command: commands.GetSocket + reply: socket.socket + + def __init__( + self, + command: typing.Union[commands.GetSocket, int], + socket: socket.socket + ): + super().__init__(command, socket) diff --git a/mitmproxy/proxy2/layers/__init__.py b/mitmproxy/proxy2/layers/__init__.py index fe3ec0bd3..d075fbf85 100644 --- a/mitmproxy/proxy2/layers/__init__.py +++ b/mitmproxy/proxy2/layers/__init__.py @@ -1,8 +1,6 @@ from . import modes from .glue import GlueLayer -from mitmproxy.proxy2.layers.old.old_http import OldHTTPLayer from .http.http import HTTPLayer -from mitmproxy.proxy2.layers.old.http1 import ClientHTTP1Layer, ServerHTTP1Layer from .tcp import TCPLayer from .tls import ClientTLSLayer, ServerTLSLayer from .websocket import WebsocketLayer @@ -10,10 +8,7 @@ from .websocket import WebsocketLayer __all__ = [ "modes", "GlueLayer", - "OldHTTPLayer", # TODO remove this and replace with ClientHTTP1Layer "HTTPLayer", - "ClientHTTP1Layer", "ServerHTTP1Layer", - "ClientHTTP2Layer", "ServerHTTP2Layer", "TCPLayer", "ClientTLSLayer", "ServerTLSLayer", "WebsocketLayer", diff --git a/mitmproxy/proxy2/layers/modes.py b/mitmproxy/proxy2/layers/modes.py index a2795aeed..1695b3fa7 100644 --- a/mitmproxy/proxy2/layers/modes.py +++ b/mitmproxy/proxy2/layers/modes.py @@ -1,23 +1,41 @@ +from mitmproxy import platform from mitmproxy.net import server_spec -from mitmproxy.proxy2 import layer -from mitmproxy.proxy2.context import Context, Server +from mitmproxy.proxy2 import commands, events, layer +from mitmproxy.proxy2.context import Server +from mitmproxy.proxy2.utils import expect class ReverseProxy(layer.Layer): - def __init__(self, context: Context): - super().__init__(context) - spec = server_spec.parse_with_mode(context.options.mode)[1] + @expect(events.Start) + def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: + spec = server_spec.parse_with_mode(self.context.options.mode)[1] self.context.server = Server(spec.address) - if spec.scheme != "http": + if spec.scheme not in ("http", "tcp"): self.context.server.tls = True - if not context.options.keep_host_header: + if not self.context.options.keep_host_header: self.context.server.sni = spec.address[0] child_layer = layer.NextLayer(self.context) self._handle_event = child_layer.handle_event + yield from child_layer.handle_event(event) class HttpProxy(layer.Layer): - def __init__(self, context: Context): - super().__init__(context) + @expect(events.Start) + def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: child_layer = layer.NextLayer(self.context) self._handle_event = child_layer.handle_event + yield from child_layer.handle_event(event) + + +class TransparentProxy(layer.Layer): + @expect(events.Start) + def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: + socket = yield commands.GetSocket(self.context.client) + try: + self.context.server.address = platform.original_addr(socket) + except Exception as e: + yield commands.Log(f"Transparent mode failure: {e!r}") + + child_layer = layer.NextLayer(self.context) + self._handle_event = child_layer.handle_event + yield from child_layer.handle_event(event) diff --git a/mitmproxy/proxy2/layers/tcp.py b/mitmproxy/proxy2/layers/tcp.py index db2ef9c2e..573fc8aad 100644 --- a/mitmproxy/proxy2/layers/tcp.py +++ b/mitmproxy/proxy2/layers/tcp.py @@ -1,4 +1,6 @@ -from mitmproxy import tcp, flow +from typing import Optional + +from mitmproxy import flow, tcp from mitmproxy.proxy2 import commands, events from mitmproxy.proxy2.context import Context from mitmproxy.proxy2.layer import Layer @@ -10,26 +12,25 @@ class TCPLayer(Layer): Simple TCP layer that just relays messages right now. """ context: Context - ignore: bool - flow: tcp.TCPFlow + flow: Optional[tcp.TCPFlow] def __init__(self, context: Context, ignore: bool = False): super().__init__(context) - self.ignore = ignore - self.flow = None + if ignore: + self.flow = None + else: + self.flow = tcp.TCPFlow(self.context.client, self.context.server, True) @expect(events.Start) def start(self, _) -> commands.TCommandGenerator: - if not self.ignore: - self.flow = tcp.TCPFlow(self.context.client, self.context.server, True) + if self.flow: yield commands.Hook("tcp_start", self.flow) if not self.context.server.connected: - try: - yield commands.OpenConnection(self.context.server) - except IOError as e: - if not self.ignore: - self.flow.error = flow.Error(str(e)) + err = yield commands.OpenConnection(self.context.server) + if err: + if self.flow: + self.flow.error = flow.Error(str(err)) yield commands.Hook("tcp_error", self.flow) yield commands.CloseConnection(self.context.client) self._handle_event = self.done @@ -47,19 +48,21 @@ class TCPLayer(Layer): send_to = self.context.client if isinstance(event, events.DataReceived): - if self.ignore: - yield commands.SendData(send_to, event.data) - else: + if self.flow: tcp_message = tcp.TCPMessage(from_client, event.data) self.flow.messages.append(tcp_message) yield commands.Hook("tcp_message", self.flow) yield commands.SendData(send_to, tcp_message.content) + else: + yield commands.SendData(send_to, event.data) elif isinstance(event, events.ConnectionClosed): yield commands.CloseConnection(send_to) - if not self.ignore: - yield commands.Hook("tcp_end", self.flow) - self._handle_event = self.done + all_done = (not self.context.client.connected and not self.context.server.connected) + if all_done: + self._handle_event = self.done + if self.flow: + yield commands.Hook("tcp_end", self.flow) @expect(events.DataReceived, events.ConnectionClosed) def done(self, _): diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py index e3ff31e46..132fd6e00 100644 --- a/mitmproxy/proxy2/layers/tls.py +++ b/mitmproxy/proxy2/layers/tls.py @@ -6,7 +6,7 @@ from OpenSSL import SSL from mitmproxy.certs import CertStore from mitmproxy.net.tls import ClientHello -from mitmproxy.proxy.protocol import tls +from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS from mitmproxy.proxy2 import commands, events, layer from mitmproxy.proxy2 import context from mitmproxy.proxy2.utils import expect @@ -362,7 +362,7 @@ class ClientTLSLayer(_TLSLayer): ).get_cert(client.sni, (client.sni,)) context.use_privatekey(privkey) context.use_certificate(cert.x509) - context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS) + context.set_cipher_list(DEFAULT_CLIENT_CIPHERS) def alpn_select_callback(conn_, options): if server.alpn in options: diff --git a/mitmproxy/proxy2/server.py b/mitmproxy/proxy2/server.py index 8ead05b20..5229d7728 100644 --- a/mitmproxy/proxy2/server.py +++ b/mitmproxy/proxy2/server.py @@ -31,7 +31,7 @@ class StreamIO(typing.NamedTuple): class TimeoutWatchdog: last_activity: float - CONNECTION_TIMEOUT = 120 + CONNECTION_TIMEOUT = 10 * 60 can_timeout: asyncio.Event blocker: int @@ -197,6 +197,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta): asyncio.ensure_future( self.shutdown_connection(command.connection) ) + elif isinstance(command, commands.GetSocket): + socket = self.transports[command.connection].w.get_extra_info("socket") + self.server_event(events.GetSocketReply(command, socket)) elif isinstance(command, glue.GlueGetConnectionHandler): self.server_event(glue.GlueGetConnectionHandlerReply(command, self)) elif isinstance(command, commands.Hook): diff --git a/mitmproxy/proxy2/utils.py b/mitmproxy/proxy2/utils.py index a78721d5d..98b53ced1 100644 --- a/mitmproxy/proxy2/utils.py +++ b/mitmproxy/proxy2/utils.py @@ -18,7 +18,8 @@ def expect(*event_types): if isinstance(event, event_types): yield from f(self, event) else: - raise AssertionError(f"Unexpected event type at {f}: Expected {event_types}, got {event}.") + event_types_str = '|'.join(e.__name__ for e in event_types) + raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.") return wrapper diff --git a/test/mitmproxy/platform/__init__.py b/test/mitmproxy/platform/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/mitmproxy/proxy2/layers/test_tcp.py b/test/mitmproxy/proxy2/layers/test_tcp.py index 52222af38..6efdc892c 100644 --- a/test/mitmproxy/proxy2/layers/test_tcp.py +++ b/test/mitmproxy/proxy2/layers/test_tcp.py @@ -1,6 +1,7 @@ -from mitmproxy.proxy2 import commands, events -from mitmproxy.proxy2.layers import tcp -from .. import tutils +from mitmproxy.proxy2.commands import CloseConnection, Hook, OpenConnection, SendData +from mitmproxy.proxy2.events import ConnectionClosed, DataReceived +from mitmproxy.proxy2.layers import TCPLayer +from ..tutils import Placeholder, playbook, reply def test_open_connection(tctx): @@ -9,121 +10,101 @@ def test_open_connection(tctx): because the server may send data first. """ assert ( - tutils.playbook(tcp.TCPLayer(tctx, True)) - << commands.OpenConnection(tctx.server) + playbook(TCPLayer(tctx, True)) + << OpenConnection(tctx.server) ) tctx.server.connected = True assert ( - tutils.playbook(tcp.TCPLayer(tctx, True)) - << None + playbook(TCPLayer(tctx, True)) + << None ) def test_open_connection_err(tctx): - f = tutils.Placeholder() + f = Placeholder() assert ( - tutils.playbook(tcp.TCPLayer(tctx)) - << commands.Hook("tcp_start", f) - >> events.HookReply(-1) - << commands.OpenConnection(tctx.server) - >> events.OpenConnectionReply(-1, "Connect call failed") - << commands.Hook("tcp_error", f) - >> events.HookReply(-1) - << commands.CloseConnection(tctx.client) + playbook(TCPLayer(tctx)) + << Hook("tcp_start", f) + >> reply() + << OpenConnection(tctx.server) + >> reply("Connect call failed") + << Hook("tcp_error", f) + >> reply() + << CloseConnection(tctx.client) ) def test_simple(tctx): """open connection, receive data, send it to peer""" - f = tutils.Placeholder() - playbook = tutils.playbook(tcp.TCPLayer(tctx)) + f = Placeholder() assert ( - playbook - << commands.Hook("tcp_start", f) - >> events.HookReply(-1) - << commands.OpenConnection(tctx.server) - >> events.OpenConnectionReply(-1, None) - >> events.DataReceived(tctx.client, b"hello!") - << commands.Hook("tcp_message", f) - >> events.HookReply(-1) - << commands.SendData(tctx.server, b"hello!") - >> events.DataReceived(tctx.server, b"hi") - << commands.Hook("tcp_message", f) - >> events.HookReply(-1) - << commands.SendData(tctx.client, b"hi") - >> events.ConnectionClosed(tctx.server) - << commands.CloseConnection(tctx.client) - << commands.Hook("tcp_end", f) - >> events.HookReply(-1) - >> events.ConnectionClosed(tctx.client) - << None + playbook(TCPLayer(tctx)) + << Hook("tcp_start", f) + >> reply() + << OpenConnection(tctx.server) + >> reply(None) + >> DataReceived(tctx.client, b"hello!") + << Hook("tcp_message", f) + >> reply() + << SendData(tctx.server, b"hello!") + >> DataReceived(tctx.server, b"hi") + << Hook("tcp_message", f) + >> reply() + << SendData(tctx.client, b"hi") + >> ConnectionClosed(tctx.server) + << CloseConnection(tctx.client) + >> ConnectionClosed(tctx.client) + << CloseConnection(tctx.server) + << Hook("tcp_end", f) + >> reply() + >> ConnectionClosed(tctx.client) + << None ) assert len(f().messages) == 2 -def test_simple_explicit(tctx): - """ - For comparison, test_simple without the playbook() sugar. - This is not substantially more code, but the playbook syntax feels cleaner to me. - """ - layer = tcp.TCPLayer(tctx) - tcp_start, = layer.handle_event(events.Start()) - flow = tcp_start.data - assert tutils._eq(tcp_start, commands.Hook("tcp_start", flow)) - open_conn, = layer.handle_event(events.HookReply(tcp_start)) - assert tutils._eq(open_conn, commands.OpenConnection(tctx.server)) - assert list(layer.handle_event(events.OpenConnectionReply(open_conn, None))) == [] - tcp_msg, = layer.handle_event(events.DataReceived(tctx.client, b"hello!")) - assert tutils._eq(tcp_msg, commands.Hook("tcp_message", flow)) - assert flow.messages[0].content == b"hello!" - - send, = layer.handle_event(events.HookReply(tcp_msg)) - assert tutils._eq(send, commands.SendData(tctx.server, b"hello!")) - close, tcp_end = layer.handle_event(events.ConnectionClosed(tctx.server)) - assert tutils._eq(close, commands.CloseConnection(tctx.client)) - assert tutils._eq(tcp_end, commands.Hook("tcp_end", flow)) - assert list(layer.handle_event(events.HookReply(tcp_end))) == [] - - def test_receive_data_before_server_connected(tctx): """ assert that data received before a server connection is established will still be forwarded. """ - f = tutils.Placeholder() + f = Placeholder() assert ( - tutils.playbook(tcp.TCPLayer(tctx)) - << commands.Hook("tcp_start", f) - >> events.HookReply(-1) - << commands.OpenConnection(tctx.server) - >> events.DataReceived(tctx.client, b"hello!") - >> events.OpenConnectionReply(-2, None) - << commands.Hook("tcp_message", f) - >> events.HookReply(-1) - << commands.SendData(tctx.server, b"hello!") + playbook(TCPLayer(tctx)) + << Hook("tcp_start", f) + >> reply() + << OpenConnection(tctx.server) + >> DataReceived(tctx.client, b"hello!") + >> reply(None, to=-2) + << Hook("tcp_message", f) + >> reply() + << SendData(tctx.server, b"hello!") ) assert f().messages -def test_receive_data_after_server_disconnected(tctx): +def test_receive_data_after_half_close(tctx): """ - data received after a connection has been closed should just be discarded. + data received after the other connection has been half-closed should still be forwarded. """ - f = tutils.Placeholder() + f = Placeholder() assert ( - tutils.playbook(tcp.TCPLayer(tctx)) - << commands.Hook("tcp_start", f) - >> events.HookReply(-1) - << commands.OpenConnection(tctx.server) - >> events.OpenConnectionReply(-1, None) - >> events.ConnectionClosed(tctx.server) - << commands.CloseConnection(tctx.client) - << commands.Hook("tcp_end", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.client, b"i'm late") - << None - ) - # not included here as it has not been sent to the server. - assert not f().messages + playbook(TCPLayer(tctx)) + << Hook("tcp_start", f) + >> reply() + << OpenConnection(tctx.server) + >> reply(None) + >> ConnectionClosed(tctx.server) + << CloseConnection(tctx.client) + >> DataReceived(tctx.client, b"i'm late") + << Hook("tcp_message", f) + >> reply() + << SendData(tctx.server, b"i'm late") + >> ConnectionClosed(tctx.client) + << CloseConnection(tctx.server) + << Hook("tcp_end", f) + >> reply() + << None + ) \ No newline at end of file diff --git a/test/mitmproxy/proxy2/test_tutils.py b/test/mitmproxy/proxy2/test_tutils.py index 3b24ce5a6..995976449 100644 --- a/test/mitmproxy/proxy2/test_tutils.py +++ b/test/mitmproxy/proxy2/test_tutils.py @@ -22,7 +22,7 @@ class TCommand(commands.Command): class TCommandReply(events.CommandReply): - pass + command: TCommand class TLayer(Layer): @@ -52,7 +52,7 @@ def test_simple(tplaybook): def test_mismatch(tplaybook): - with pytest.raises(AssertionError, message="Playbook mismatch"): + with pytest.raises(AssertionError, match="Playbook mismatch"): assert ( tplaybook >> TEvent([]) @@ -135,7 +135,7 @@ def test_fork_placeholder(tplaybook): assert f2() == p2_flow # re-using the old placeholder does not work. - with pytest.raises(AssertionError, message="Playbook mismatch"): + with pytest.raises(AssertionError, match="Playbook mismatch"): assert ( p2 >> TEvent([p2_flow]) @@ -146,7 +146,7 @@ def test_fork_placeholder(tplaybook): def test_unfinished(tplaybook): """We show a warning when playbooks aren't asserted.""" tplaybook >> TEvent() - with pytest.raises(RuntimeError, message="Unfinished playbook"): + with pytest.raises(RuntimeError, match="Unfinished playbook"): tplaybook.__del__() tplaybook._errored = True tplaybook.__del__() diff --git a/test/mitmproxy/proxy2/tutils.py b/test/mitmproxy/proxy2/tutils.py index 1669c9718..c9e81ecc6 100644 --- a/test/mitmproxy/proxy2/tutils.py +++ b/test/mitmproxy/proxy2/tutils.py @@ -1,12 +1,13 @@ +import collections.abc import copy import difflib import itertools import typing -import collections - from mitmproxy.proxy2 import commands, context from mitmproxy.proxy2 import events +from mitmproxy.proxy2.context import ConnectionState +from mitmproxy.proxy2.events import command_reply_subclasses from mitmproxy.proxy2.layer import Layer, NextLayer TPlaybookEntry = typing.Union[commands.Command, events.Event] @@ -14,8 +15,8 @@ TPlaybook = typing.List[TPlaybookEntry] def _eq( - a: TPlaybookEntry, - b: TPlaybookEntry + a: TPlaybookEntry, + b: TPlaybookEntry ) -> bool: """Compare two commands/events, and possibly update placeholders.""" if type(a) != type(b): @@ -43,24 +44,28 @@ def _eq( def eq( - a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]], - b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]] + a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]], + b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]] ): """ Compare an indiviual event/command or a list of events/commands. """ - if isinstance(a, collections.Iterable) and isinstance(b, collections.Iterable): + if isinstance(a, collections.abc.Iterable) and isinstance(b, collections.abc.Iterable): return all( _eq(x, y) for x, y in itertools.zip_longest(a, b) ) return _eq(a, b) -T = typing.TypeVar('T', bound=Layer) +def _str(x: typing.Union[events.Event, commands.Command]): + arrow = ">>" if isinstance(x, events.Event) else "<<" + x = str(x) \ + .replace('Placeholder:None', '') \ + .replace('Placeholder:', '') + return f"{arrow} {x}" -# noinspection PyPep8Naming -class playbook(typing.Generic[T]): +class playbook: """ Assert that a layer emits the expected commands in reaction to a given sequence of events. For example, the following code asserts that the TCP layer emits an OpenConnection command @@ -80,7 +85,7 @@ class playbook(typing.Generic[T]): x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1]))) assert x2 == [] """ - layer: T + layer: Layer """The base layer""" expected: TPlaybook """expected command/event sequence""" @@ -92,10 +97,10 @@ class playbook(typing.Generic[T]): """If True, log statements are ignored.""" def __init__( - self, - layer: T, - expected: typing.Optional[TPlaybook] = None, - ignore_log: bool = True + self, + layer: Layer, + expected: typing.Optional[TPlaybook] = None, + ignore_log: bool = True ): if expected is None: expected = [ @@ -130,11 +135,12 @@ class playbook(typing.Generic[T]): if isinstance(x, commands.Command): pass else: - if isinstance(x, events.CommandReply): - if isinstance(x.command, int) and abs(x.command) < len(self.actual): - x.command = self.actual[x.command] - if hasattr(x, "_playbook_eval"): - x._playbook_eval(self) + if hasattr(x, "playbook_eval"): + x = self.expected[i] = x.playbook_eval(self) + if isinstance(x, events.OpenConnectionReply): + x.command.connection.state = ConnectionState.OPEN + elif isinstance(x, events.ConnectionClosed): + x.connection.state &= ~ConnectionState.CAN_READ self.actual.append(x) self.actual.extend( @@ -148,14 +154,6 @@ class playbook(typing.Generic[T]): if not eq(self.expected, self.actual): self._errored = True - - def _str(x): - arrow = ">>" if isinstance(x, events.Event) else "<<" - x = str(x) \ - .replace('Placeholder:None', '') \ - .replace('Placeholder:', '') - return f"{arrow} {x}" - diff = "\n".join(difflib.ndiff( [_str(x) for x in self.expected], [_str(x) for x in self.actual] @@ -180,6 +178,48 @@ class playbook(typing.Generic[T]): return copy.deepcopy(self) +class reply(events.Event): + args: typing.Tuple[typing.Any, ...] + to: typing.Union[commands.Command, int] + side_effect: typing.Callable[[commands.Command], typing.Any] + + def __init__( + self, + *args, + to: typing.Union[commands.Command, int] = -1, + side_effect: typing.Callable[[commands.Command], typing.Any] = lambda cmd: None + ): + """Utility method to reply to the latest hook in playbooks.""" + self.args = args + self.to = to + self.side_effect = side_effect + + def playbook_eval(self, playbook: playbook) -> events.CommandReply: + if isinstance(self.to, int): + expected = playbook.expected[:playbook.expected.index(self)] + assert abs(self.to) < len(expected) + to = expected[self.to] + if not isinstance(to, commands.Command): + raise AssertionError(f"There is no command at offset {self.to}: {to}") + else: + self.to = to + for cmd in reversed(playbook.actual): + if eq(self.to, cmd): + self.to = cmd + break + else: + actual_str = "\n".join(_str(x) for x in playbook.actual) + raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}") + + self.side_effect(self.to) + reply_cls = command_reply_subclasses[type(self.to)] + try: + inst = reply_cls(self.to, *self.args) + except TypeError as e: + raise ValueError(f"Cannot instantiate {reply_cls.__name__}: {e}") + return inst + + class _Placeholder: """ Placeholder value in playbooks, so that objects (flows in particular) can be referenced before @@ -209,6 +249,7 @@ class _Placeholder: return f"Placeholder:{str(self.obj)}" +# noinspection PyPep8Naming def Placeholder() -> typing.Any: return _Placeholder() @@ -222,7 +263,7 @@ class EchoLayer(Layer): def next_layer( - layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]] + layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]] ) -> events.HookReply: """ Helper function to simplify the syntax for next_layer events from this: @@ -238,7 +279,9 @@ def next_layer( << commands.Hook("next_layer", next_layer) >> tutils.next_layer(next_layer, tutils.EchoLayer) + >> tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) """ + raise RuntimeError("Does tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) work?") if isinstance(layer, type): def make_layer(ctx: context.Context) -> Layer: return layer(ctx)