diff --git a/mitmproxy/proxy2/commands.py b/mitmproxy/proxy2/commands.py index b84a6471c..b115e8220 100644 --- a/mitmproxy/proxy2/commands.py +++ b/mitmproxy/proxy2/commands.py @@ -53,7 +53,7 @@ class SendData(ConnectionCommand): self.data = data def __repr__(self): - target = type(self.connection).__name__.lower() + target = str(self.connection).split("(", 1)[0].lower() return f"SendData({target}, {self.data})" diff --git a/mitmproxy/proxy2/events.py b/mitmproxy/proxy2/events.py index 1ec5cebe2..9e1003713 100644 --- a/mitmproxy/proxy2/events.py +++ b/mitmproxy/proxy2/events.py @@ -66,7 +66,7 @@ class CommandReply(Event): command: commands.Command reply: typing.Any - def __init__(self, command: typing.Union[commands.Command, int], reply: typing.Any): + def __init__(self, command: commands.Command, reply: typing.Any): self.command = command self.reply = reply @@ -91,7 +91,7 @@ class OpenConnectionReply(CommandReply): def __init__( self, - command: typing.Union[commands.OpenConnection, int], + command: commands.OpenConnection, err: typing.Optional[str] ): super().__init__(command, err) @@ -113,7 +113,7 @@ class GetSocketReply(CommandReply): def __init__( self, - command: typing.Union[commands.GetSocket, int], + command: commands.GetSocket, socket: socket.socket ): super().__init__(command, socket) diff --git a/mitmproxy/proxy2/layers/http/http.py b/mitmproxy/proxy2/layers/http/http.py index 9bb5e79c0..df3a5c1e2 100644 --- a/mitmproxy/proxy2/layers/http/http.py +++ b/mitmproxy/proxy2/layers/http/http.py @@ -497,6 +497,8 @@ class HttpStream(Layer): self.flow.error = flow.Error(err) yield commands.Hook("error", self.flow) return + else: + self.flow.server_conn = connection yield SendHttp(RequestHeaders(self.flow.request, self.stream_id), connection) @@ -660,6 +662,7 @@ class HTTPLayer(Layer): can_reuse_context_connection = ( self.context.server not in self.connections and self.context.server.connected and + self.context.server.address == event.address and self.context.server.tls == event.tls ) if can_reuse_context_connection: diff --git a/test/mitmproxy/proxy2/conftest.py b/test/mitmproxy/proxy2/conftest.py index 03ce5b7f7..be5db56e1 100644 --- a/test/mitmproxy/proxy2/conftest.py +++ b/test/mitmproxy/proxy2/conftest.py @@ -1,12 +1,15 @@ import pytest from mitmproxy import options +from mitmproxy.addons.proxyserver import Proxyserver from mitmproxy.proxy2 import context @pytest.fixture def tctx(): + opts = options.Options() + Proxyserver().load(opts) return context.Context( context.Client(("client", 1234)), - options.Options() + opts ) diff --git a/test/mitmproxy/proxy2/layers/test_websocket.py b/test/mitmproxy/proxy2/layers/_test_websocket.py similarity index 99% rename from test/mitmproxy/proxy2/layers/test_websocket.py rename to test/mitmproxy/proxy2/layers/_test_websocket.py index 512efc3ba..f329d2966 100644 --- a/test/mitmproxy/proxy2/layers/test_websocket.py +++ b/test/mitmproxy/proxy2/layers/_test_websocket.py @@ -13,7 +13,7 @@ from .. import tutils @pytest.fixture def ws_playbook(tctx): tctx.server.connected = True - playbook = tutils.playbook( + playbook = tutils.Playbook( websocket.WebsocketLayer( tctx, tflow.twebsocketflow().handshake_flow diff --git a/test/mitmproxy/proxy2/layers/test_http.py b/test/mitmproxy/proxy2/layers/test_http.py index 2749c6b45..0e5897270 100644 --- a/test/mitmproxy/proxy2/layers/test_http.py +++ b/test/mitmproxy/proxy2/layers/test_http.py @@ -1,29 +1,200 @@ -def test_http_proxy(): +import pytest + +from mitmproxy.http import HTTPResponse +from mitmproxy.proxy.protocol.http import HTTPMode +from mitmproxy.proxy2.commands import Hook, OpenConnection, SendData +from mitmproxy.proxy2.events import ConnectionClosed, DataReceived +from mitmproxy.proxy2.layers import tls +from mitmproxy.proxy2.layers.http import http +from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer + + +def test_http_proxy(tctx): """Test a simple HTTP GET / request""" + server = Placeholder() + flow = Placeholder() + assert ( + Playbook(http.HTTPLayer(tctx, HTTPMode.regular)) + >> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n") + << Hook("requestheaders", flow) + >> reply() + << Hook("request", flow) + >> reply() + << OpenConnection(server) + >> reply(None) + << SendData(server, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n") + >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World") + << Hook("responseheaders", flow) + >> reply() + >> DataReceived(server, b"!") + << Hook("response", flow) + >> reply() + << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") + ) + assert server().address == ("example.com", 80) -def test_https_proxy_eager(): - """Test a CONNECT request, followed by TLS, followed by a HTTP GET /""" +@pytest.mark.parametrize("strategy", ["lazy", "eager"]) +def test_https_proxy(strategy, tctx): + """Test a CONNECT request, followed by a HTTP GET /""" + server = Placeholder() + flow = Placeholder() + playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular)) + tctx.options.connection_strategy = strategy + + (playbook + >> DataReceived(tctx.client, b"CONNECT example.proxy:80 HTTP/1.1\r\n\r\n") + << Hook("http_connect", Placeholder()) + >> reply()) + if strategy == "eager": + (playbook + << OpenConnection(server) + >> reply(None)) + (playbook + << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n') + >> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n") + << Hook("next_layer", Placeholder()) + >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent)) + << Hook("requestheaders", flow) + >> reply() + << Hook("request", flow) + >> reply()) + if strategy == "lazy": + (playbook + << OpenConnection(server) + >> reply(None)) + (playbook + << SendData(server, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n") + >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") + << Hook("responseheaders", flow) + >> reply() + << Hook("response", flow) + >> reply() + << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")) + assert playbook -def test_https_proxy_lazy(): - """Test a CONNECT request, followed by TLS, followed by a HTTP GET /""" +@pytest.mark.parametrize("https_client", [False, True]) +@pytest.mark.parametrize("https_server", [False, True]) +@pytest.mark.parametrize("strategy", ["lazy", "eager"]) +def test_redirect(strategy, https_server, https_client, tctx): + """Test redirects between http:// and https:// in regular proxy mode.""" + server = Placeholder() + flow = Placeholder() + tctx.options.connection_strategy = strategy + p = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) + + def redirect(hook: Hook): + if https_server: + hook.data.request.url = "https://redirected.site/" + else: + hook.data.request.url = "http://redirected.site/" + + if https_client: + p >> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n") + if strategy == "eager": + p << OpenConnection(Placeholder()) + p >> reply(None) + p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n') + p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + p << Hook("next_layer", Placeholder()) + p >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent)) + else: + p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") + p << Hook("request", flow) + p >> reply(side_effect=redirect) + p << OpenConnection(server) + p >> reply(None) + if https_server: + p << tls.EstablishServerTLS(server) + p >> reply_establish_server_tls() + p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n") + p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") + p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") + + assert p + if https_server: + assert server().address == ("redirected.site", 443) + else: + assert server().address == ("redirected.site", 80) -def test_http_to_https(): - """Test a simple HTTP GET request that is being rewritten to HTTPS by an addon.""" - - -def test_http_redirect(): - """Test a simple HTTP GET request that redirected to another host""" - - -def test_multiple_server_connections(): +def test_multiple_server_connections(tctx): """Test multiple requests being rewritten to different targets.""" + server1 = Placeholder() + server2 = Placeholder() + + def redirect(to: str): + def side_effect(hook: Hook): + hook.data.request.url = to + + return side_effect + + assert ( + Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) + >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") + << Hook("request", Placeholder()) + >> reply(side_effect=redirect("http://one.redirect/")) + << OpenConnection(server1) + >> reply(None) + << SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n") + >> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") + << Hook("request", Placeholder()) + >> reply(side_effect=redirect("http://two.redirect/")) + << OpenConnection(server2) + >> reply(None) + << SendData(server2, b"GET / HTTP/1.1\r\nHost: two.redirect\r\n\r\n") + >> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + ) + assert server1().address == ("one.redirect", 80) + assert server2().address == ("two.redirect", 80) -def test_http_reply_from_proxy(): +def test_http_reply_from_proxy(tctx): """Test a response served by mitmproxy itself.""" -def test_disconnect_while_intercept(): - """Test a server disconnect while a request is intercepted.""" \ No newline at end of file + def reply_from_proxy(hook: Hook): + hook.data.response = HTTPResponse.make(418) + + assert ( + Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) + >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") + << Hook("request", Placeholder()) + >> reply(side_effect=reply_from_proxy) + << SendData(tctx.client, b"HTTP/1.1 418 I'm a teapot\r\ncontent-length: 0\r\n\r\n") + ) + + +def test_disconnect_while_intercept(tctx): + """Test a server disconnect while a request is intercepted.""" + tctx.options.connection_strategy = "eager" + + server1 = Placeholder() + server2 = Placeholder() + flow = Placeholder() + + assert ( + Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) + >> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n") + << Hook("http_connect", Placeholder()) + >> reply() + << OpenConnection(server1) + >> reply(None) + << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n') + >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + << Hook("next_layer", Placeholder()) + >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent)) + << Hook("request", flow) + >> ConnectionClosed(server1) + >> reply(to=-2) + << OpenConnection(server2) + >> reply(None) + << SendData(server2, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + >> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + ) + assert server1() != server2() + assert flow().server_conn == server2() diff --git a/test/mitmproxy/proxy2/layers/test_tcp.py b/test/mitmproxy/proxy2/layers/test_tcp.py index 5217ab547..ca2670994 100644 --- a/test/mitmproxy/proxy2/layers/test_tcp.py +++ b/test/mitmproxy/proxy2/layers/test_tcp.py @@ -1,7 +1,7 @@ from mitmproxy.proxy2.commands import CloseConnection, Hook, OpenConnection, SendData from mitmproxy.proxy2.events import ConnectionClosed, DataReceived from mitmproxy.proxy2.layers import TCPLayer -from ..tutils import Placeholder, playbook, reply +from ..tutils import Placeholder, Playbook, reply def test_open_connection(tctx): @@ -10,13 +10,13 @@ def test_open_connection(tctx): because the server may send data first. """ assert ( - playbook(TCPLayer(tctx, True)) + Playbook(TCPLayer(tctx, True)) << OpenConnection(tctx.server) ) tctx.server.connected = True assert ( - playbook(TCPLayer(tctx, True)) + Playbook(TCPLayer(tctx, True)) << None ) @@ -24,7 +24,7 @@ def test_open_connection(tctx): def test_open_connection_err(tctx): f = Placeholder() assert ( - playbook(TCPLayer(tctx)) + Playbook(TCPLayer(tctx)) << Hook("tcp_start", f) >> reply() << OpenConnection(tctx.server) @@ -40,7 +40,7 @@ def test_simple(tctx): f = Placeholder() assert ( - playbook(TCPLayer(tctx)) + Playbook(TCPLayer(tctx)) << Hook("tcp_start", f) >> reply() << OpenConnection(tctx.server) @@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx): will still be forwarded. """ assert ( - playbook(TCPLayer(tctx), hooks=False) + Playbook(TCPLayer(tctx), hooks=False) << OpenConnection(tctx.server) >> DataReceived(tctx.client, b"hello!") >> reply(None, to=-2) @@ -84,7 +84,7 @@ def test_receive_data_after_half_close(tctx): data received after the other connection has been half-closed should still be forwarded. """ assert ( - playbook(TCPLayer(tctx), hooks=False) + Playbook(TCPLayer(tctx), hooks=False) << OpenConnection(tctx.server) >> reply(None) >> ConnectionClosed(tctx.server) diff --git a/test/mitmproxy/proxy2/layers/test_tls.py b/test/mitmproxy/proxy2/layers/test_tls.py index 8f0da853f..aa23a16df 100644 --- a/test/mitmproxy/proxy2/layers/test_tls.py +++ b/test/mitmproxy/proxy2/layers/test_tls.py @@ -87,7 +87,7 @@ class SSLTest: ) -def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None: +def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: context.Connection) -> None: tssl.obj.write(b"Hello World") data = tutils.Placeholder() assert ( @@ -110,7 +110,7 @@ class TlsEchoLayer(tutils.EchoLayer): yield from super()._handle_event(event) -def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest): +def interact(playbook: tutils.Playbook, conn: context.Connection, tssl: SSLTest): data = tutils.Placeholder() assert ( playbook @@ -149,7 +149,7 @@ class TestServerTLS: # Handshake assert ( - tutils.playbook(layer) + tutils.Playbook(layer) >> events.DataReceived(tctx.client, b"Hello World") << commands.SendData(tctx.client, b"hello world") >> events.DataReceived(tctx.server, b"Foo") @@ -158,7 +158,7 @@ class TestServerTLS: def test_simple(self, tctx): layer = tls.ServerTLSLayer(tctx) - playbook = tutils.playbook(layer) + playbook = tutils.Playbook(layer) tctx.server.connected = True tctx.server.address = ("example.com", 443) @@ -170,7 +170,7 @@ class TestServerTLS: playbook >> events.DataReceived(tctx.client, b"establish-server-tls") << commands.Hook("next_layer", tutils.Placeholder()) - >> tutils.next_layer(TlsEchoLayer) + >> tutils.reply_next_layer(TlsEchoLayer) << commands.Hook("tls_start", tutils.Placeholder()) >> reply_tls_start() << commands.SendData(tctx.server, data) @@ -196,21 +196,21 @@ class TestServerTLS: _test_echo(playbook, tssl, tctx.server) -def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]: +def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.Playbook, tls.ClientTLSLayer]: # This is a bit contrived as the client layer expects a server layer as parent. # We also set child layers manually to avoid NextLayer noise. server_layer = tls.ServerTLSLayer(tctx) client_layer = tls.ClientTLSLayer(tctx) server_layer.child_layer = client_layer client_layer.child_layer = TlsEchoLayer(tctx) - playbook = tutils.playbook(server_layer) + playbook = tutils.Playbook(server_layer) return playbook, client_layer def _test_tls_client_server( tctx: context.Context, sni: typing.Optional[bytes] -) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]: +) -> typing.Tuple[tutils.Playbook, tls.ClientTLSLayer, SSLTest]: playbook, client_layer = _make_client_tls_layer(tctx) tctx.server.tls = True tctx.server.address = ("example.com", 443) diff --git a/test/mitmproxy/proxy2/test_layer.py b/test/mitmproxy/proxy2/test_layer.py index 3a8efc95f..af5e262fb 100644 --- a/test/mitmproxy/proxy2/test_layer.py +++ b/test/mitmproxy/proxy2/test_layer.py @@ -5,7 +5,7 @@ from test.mitmproxy.proxy2 import tutils class TestNextLayer: def test_simple(self, tctx): nl = layer.NextLayer(tctx) - playbook = tutils.playbook(nl, hooks=True) + playbook = tutils.Playbook(nl, hooks=True) assert ( playbook @@ -32,7 +32,7 @@ class TestNextLayer: a reply from the proxy core. """ nl = layer.NextLayer(tctx) - playbook = tutils.playbook(nl) + playbook = tutils.Playbook(nl) assert ( playbook @@ -52,7 +52,7 @@ class TestNextLayer: def test_func_references(self, tctx): nl = layer.NextLayer(tctx) - playbook = tutils.playbook(nl) + playbook = tutils.Playbook(nl) assert ( playbook diff --git a/test/mitmproxy/proxy2/test_tutils.py b/test/mitmproxy/proxy2/test_tutils.py index 29b65171d..c79cd2efc 100644 --- a/test/mitmproxy/proxy2/test_tutils.py +++ b/test/mitmproxy/proxy2/test_tutils.py @@ -38,7 +38,7 @@ class TLayer(Layer): @pytest.fixture def tplaybook(tctx): - return tutils.playbook(TLayer(tctx), expected=[]) + return tutils.Playbook(TLayer(tctx), expected=[]) def test_simple(tplaybook): @@ -164,7 +164,7 @@ def test_command_reply(tplaybook): def test_default_playbook(tctx): - p = tutils.playbook(TLayer(tctx)) + p = tutils.Playbook(TLayer(tctx)) assert p assert len(p.actual) == 1 assert isinstance(p.actual[0], events.Start) diff --git a/test/mitmproxy/proxy2/tutils.py b/test/mitmproxy/proxy2/tutils.py index c45440033..4f8a3baf5 100644 --- a/test/mitmproxy/proxy2/tutils.py +++ b/test/mitmproxy/proxy2/tutils.py @@ -2,7 +2,8 @@ import collections.abc import copy import difflib import itertools -import sys +import re +import traceback import typing from mitmproxy.proxy2 import commands, context @@ -10,14 +11,15 @@ from mitmproxy.proxy2 import events from mitmproxy.proxy2.context import ConnectionState from mitmproxy.proxy2.events import command_reply_subclasses from mitmproxy.proxy2.layer import Layer, NextLayer +from mitmproxy.proxy2.layers import tls -TPlaybookEntry = typing.Union[commands.Command, events.Event] -TPlaybook = typing.List[TPlaybookEntry] +PlaybookEntry = typing.Union[commands.Command, events.Event] +PlaybookEntryList = typing.List[PlaybookEntry] def _eq( - a: TPlaybookEntry, - b: TPlaybookEntry + a: PlaybookEntry, + b: PlaybookEntry ) -> bool: """Compare two commands/events, and possibly update placeholders.""" if type(a) != type(b): @@ -45,8 +47,8 @@ def _eq( def eq( - a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]], - b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]] + a: typing.Union[PlaybookEntry, typing.Iterable[PlaybookEntry]], + b: typing.Union[PlaybookEntry, typing.Iterable[PlaybookEntry]] ): """ Compare an indiviual event/command or a list of events/commands. @@ -58,15 +60,39 @@ def eq( return _eq(a, b) -def _fmt_entry(x: TPlaybookEntry): +def _fmt_entry(x: PlaybookEntry): arrow = ">>" if isinstance(x, events.Event) else "<<" - x = str(x) \ - .replace('Placeholder:None', '') \ - .replace('Placeholder:', '') + x = str(x) + x = re.sub('Placeholder:None', '', x, flags=re.IGNORECASE) + x = re.sub('Placeholder:', '', x, flags=re.IGNORECASE) return f"{arrow} {x}" -class playbook: +def _merge_sends(lst: PlaybookEntryList) -> PlaybookEntryList: + merged = lst[:1] + for x in lst[1:]: + prev = merged[-1] + two_subsequent_sends_to_the_same_remote = ( + isinstance(x, commands.SendData) and + isinstance(prev, commands.SendData) and + x.connection is prev.connection + ) + if two_subsequent_sends_to_the_same_remote: + prev.data += x.data + else: + merged.append(x) + return merged + + +class _TracebackInPlaybook(commands.Command): + def __init__(self, exc): + self.e = exc + + def __repr__(self): + return self.e + + +class Playbook: """ Assert that a layer emits the expected commands in reaction to a given sequence of events. For example, the following code asserts that the TCP layer emits an OpenConnection command @@ -88,9 +114,9 @@ class playbook: """ layer: Layer """The base layer""" - expected: TPlaybook + expected: PlaybookEntryList """expected command/event sequence""" - actual: TPlaybook + actual: PlaybookEntryList """actual command/event sequence""" _errored: bool """used to check if playbook as been fully asserted""" @@ -98,13 +124,15 @@ class playbook: """If False, the playbook specification doesn't contain log commands.""" hooks: bool """If False, the playbook specification doesn't include hooks or hook replies. They are automatically replied to.""" + merge_sends: bool + """If True, subsequent SendData commands to the same remote will be merged in both expected and actual playbook.""" def __init__( self, layer: Layer, hooks: bool = True, logs: bool = False, - expected: typing.Optional[TPlaybook] = None, + expected: typing.Optional[PlaybookEntryList] = None, ): if expected is None: expected = [ @@ -121,8 +149,6 @@ class playbook: def __rshift__(self, e): """Add an event to send""" assert isinstance(e, events.Event) - if not self.hooks and isinstance(e, events.HookReply): - raise ValueError(f"Playbook must not contain hook replies if hooks=False: {e}") self.expected.append(e) return self @@ -131,10 +157,6 @@ class playbook: if c is None: return self assert isinstance(c, commands.Command) - if not self.logs and isinstance(c, commands.Log): - raise ValueError(f"Playbook must not contain log commands if logs=False: {c}") - if not self.hooks and isinstance(c, commands.Hook): - raise ValueError(f"Playbook must not contain hook commands if hooks=False: {c}") self.expected.append(c) return self @@ -149,25 +171,41 @@ class playbook: else: if hasattr(x, "playbook_eval"): x = self.expected[i] = x.playbook_eval(self) + for name, value in vars(x).items(): + if isinstance(value, _Placeholder): + setattr(x, name, value()) if isinstance(x, events.OpenConnectionReply) and not x.reply: x.command.connection.state = ConnectionState.OPEN elif isinstance(x, events.ConnectionClosed): x.connection.state &= ~ConnectionState.CAN_READ self.actual.append(x) - cmds = list(self.layer.handle_event(x)) + try: + cmds = list(self.layer.handle_event(x)) + except Exception: + self.actual.append(_TracebackInPlaybook(traceback.format_exc())) + break self.actual.extend(cmds) if not self.logs: for offset, cmd in enumerate(cmds): - if isinstance(cmd, commands.Log): - self.expected.insert(i + 1 + offset, cmd) + pos = i + 1 + offset + if isinstance(cmd, commands.Log) and not isinstance(self.expected[pos], commands.Log): + self.expected.insert(pos, cmd) if not self.hooks: last_cmd = self.actual[-1] - if isinstance(last_cmd, commands.Hook): - self.expected.insert(i + len(cmds), last_cmd) - self.expected.insert(i + len(cmds) + 1, events.HookReply(last_cmd)) + pos = i + len(cmds) + need_to_emulate_hook = ( + isinstance(last_cmd, commands.Hook) and + not (isinstance(self.expected[pos], commands.Hook) and self.expected[pos].name == last_cmd.name) + ) + if need_to_emulate_hook: + self.expected.insert(pos, last_cmd) + self.expected.insert(pos + 1, events.HookReply(last_cmd)) i += 1 + self.actual = _merge_sends(self.actual) + self.expected = _merge_sends(self.expected) + if not eq(self.expected, self.actual): self._errored = True diff = "\n".join(difflib.ndiff( @@ -210,7 +248,7 @@ class reply(events.Event): self.to = to self.side_effect = side_effect - def playbook_eval(self, playbook: playbook) -> events.CommandReply: + def playbook_eval(self, playbook: Playbook) -> events.CommandReply: if isinstance(self.to, int): expected = playbook.expected[:playbook.expected.index(self)] assert abs(self.to) < len(expected) @@ -225,7 +263,7 @@ class reply(events.Event): break else: actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual) - raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}") + raise AssertionError(f"Expected command {self.to} did not occur:\n{actual_str}") assert isinstance(self.to, commands.Command) self.side_effect(self.to) @@ -279,23 +317,12 @@ class EchoLayer(Layer): yield commands.SendData(event.connection, event.data.lower()) -def next_layer( +def reply_next_layer( layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]], *args, **kwargs ) -> reply: - """ - Helper function to simplify the syntax for next_layer events from this: - - << commands.Hook("next_layer", next_layer) - ) - next_layer().layer = tutils.EchoLayer(next_layer().context) - assert ( - playbook - >> events.HookReply(-1) - - to this: - + """Helper function to simplify the syntax for next_layer events to this: << commands.Hook("next_layer", next_layer) >> tutils.next_layer(next_layer, tutils.EchoLayer) """ @@ -305,3 +332,15 @@ def next_layer( hook.data.layer = layer(hook.data.context) return reply(*args, side_effect=set_layer, **kwargs) + + +def reply_establish_server_tls(**kwargs) -> reply: + """Helper function to simplify the syntax for EstablishServerTls events to this: + << tls.EstablishServerTLS(server) + >> tutils.reply_establish_server_tls() + """ + + def fake_tls(cmd: tls.EstablishServerTLS) -> None: + cmd.connection.tls_established = True + + return reply(None, side_effect=fake_tls, **kwargs)