[sans-io] http tests++

This commit is contained in:
Maximilian Hils 2019-11-13 03:52:38 +01:00
parent 1c80dfe17f
commit 6ee7802bf1
11 changed files with 301 additions and 85 deletions

View File

@ -53,7 +53,7 @@ class SendData(ConnectionCommand):
self.data = data
def __repr__(self):
target = type(self.connection).__name__.lower()
target = str(self.connection).split("(", 1)[0].lower()
return f"SendData({target}, {self.data})"

View File

@ -66,7 +66,7 @@ class CommandReply(Event):
command: commands.Command
reply: typing.Any
def __init__(self, command: typing.Union[commands.Command, int], reply: typing.Any):
def __init__(self, command: commands.Command, reply: typing.Any):
self.command = command
self.reply = reply
@ -91,7 +91,7 @@ class OpenConnectionReply(CommandReply):
def __init__(
self,
command: typing.Union[commands.OpenConnection, int],
command: commands.OpenConnection,
err: typing.Optional[str]
):
super().__init__(command, err)
@ -113,7 +113,7 @@ class GetSocketReply(CommandReply):
def __init__(
self,
command: typing.Union[commands.GetSocket, int],
command: commands.GetSocket,
socket: socket.socket
):
super().__init__(command, socket)

View File

@ -497,6 +497,8 @@ class HttpStream(Layer):
self.flow.error = flow.Error(err)
yield commands.Hook("error", self.flow)
return
else:
self.flow.server_conn = connection
yield SendHttp(RequestHeaders(self.flow.request, self.stream_id), connection)
@ -660,6 +662,7 @@ class HTTPLayer(Layer):
can_reuse_context_connection = (
self.context.server not in self.connections and
self.context.server.connected and
self.context.server.address == event.address and
self.context.server.tls == event.tls
)
if can_reuse_context_connection:

View File

@ -1,12 +1,15 @@
import pytest
from mitmproxy import options
from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.proxy2 import context
@pytest.fixture
def tctx():
opts = options.Options()
Proxyserver().load(opts)
return context.Context(
context.Client(("client", 1234)),
options.Options()
opts
)

View File

@ -13,7 +13,7 @@ from .. import tutils
@pytest.fixture
def ws_playbook(tctx):
tctx.server.connected = True
playbook = tutils.playbook(
playbook = tutils.Playbook(
websocket.WebsocketLayer(
tctx,
tflow.twebsocketflow().handshake_flow

View File

@ -1,29 +1,200 @@
def test_http_proxy():
import pytest
from mitmproxy.http import HTTPResponse
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2.commands import Hook, OpenConnection, SendData
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.layers.http import http
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
def test_http_proxy(tctx):
"""Test a simple HTTP GET / request"""
server = Placeholder()
flow = Placeholder()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< Hook("requestheaders", flow)
>> reply()
<< Hook("request", flow)
>> reply()
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World")
<< Hook("responseheaders", flow)
>> reply()
>> DataReceived(server, b"!")
<< Hook("response", flow)
>> reply()
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
)
assert server().address == ("example.com", 80)
def test_https_proxy_eager():
"""Test a CONNECT request, followed by TLS, followed by a HTTP GET /"""
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
def test_https_proxy(strategy, tctx):
"""Test a CONNECT request, followed by a HTTP GET /"""
server = Placeholder()
flow = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
tctx.options.connection_strategy = strategy
(playbook
>> DataReceived(tctx.client, b"CONNECT example.proxy:80 HTTP/1.1\r\n\r\n")
<< Hook("http_connect", Placeholder())
>> reply())
if strategy == "eager":
(playbook
<< OpenConnection(server)
>> reply(None))
(playbook
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< Hook("next_layer", Placeholder())
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
<< Hook("requestheaders", flow)
>> reply()
<< Hook("request", flow)
>> reply())
if strategy == "lazy":
(playbook
<< OpenConnection(server)
>> reply(None))
(playbook
<< SendData(server, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
<< Hook("responseheaders", flow)
>> reply()
<< Hook("response", flow)
>> reply()
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!"))
assert playbook
def test_https_proxy_lazy():
"""Test a CONNECT request, followed by TLS, followed by a HTTP GET /"""
@pytest.mark.parametrize("https_client", [False, True])
@pytest.mark.parametrize("https_server", [False, True])
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
def test_redirect(strategy, https_server, https_client, tctx):
"""Test redirects between http:// and https:// in regular proxy mode."""
server = Placeholder()
flow = Placeholder()
tctx.options.connection_strategy = strategy
p = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
def redirect(hook: Hook):
if https_server:
hook.data.request.url = "https://redirected.site/"
else:
hook.data.request.url = "http://redirected.site/"
if https_client:
p >> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
if strategy == "eager":
p << OpenConnection(Placeholder())
p >> reply(None)
p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
p << Hook("next_layer", Placeholder())
p >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
else:
p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
p << Hook("request", flow)
p >> reply(side_effect=redirect)
p << OpenConnection(server)
p >> reply(None)
if https_server:
p << tls.EstablishServerTLS(server)
p >> reply_establish_server_tls()
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
assert p
if https_server:
assert server().address == ("redirected.site", 443)
else:
assert server().address == ("redirected.site", 80)
def test_http_to_https():
"""Test a simple HTTP GET request that is being rewritten to HTTPS by an addon."""
def test_http_redirect():
"""Test a simple HTTP GET request that redirected to another host"""
def test_multiple_server_connections():
def test_multiple_server_connections(tctx):
"""Test multiple requests being rewritten to different targets."""
server1 = Placeholder()
server2 = Placeholder()
def redirect(to: str):
def side_effect(hook: Hook):
hook.data.request.url = to
return side_effect
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< Hook("request", Placeholder())
>> reply(side_effect=redirect("http://one.redirect/"))
<< OpenConnection(server1)
>> reply(None)
<< SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n")
>> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< Hook("request", Placeholder())
>> reply(side_effect=redirect("http://two.redirect/"))
<< OpenConnection(server2)
>> reply(None)
<< SendData(server2, b"GET / HTTP/1.1\r\nHost: two.redirect\r\n\r\n")
>> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
)
assert server1().address == ("one.redirect", 80)
assert server2().address == ("two.redirect", 80)
def test_http_reply_from_proxy():
def test_http_reply_from_proxy(tctx):
"""Test a response served by mitmproxy itself."""
def test_disconnect_while_intercept():
def reply_from_proxy(hook: Hook):
hook.data.response = HTTPResponse.make(418)
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< Hook("request", Placeholder())
>> reply(side_effect=reply_from_proxy)
<< SendData(tctx.client, b"HTTP/1.1 418 I'm a teapot\r\ncontent-length: 0\r\n\r\n")
)
def test_disconnect_while_intercept(tctx):
"""Test a server disconnect while a request is intercepted."""
tctx.options.connection_strategy = "eager"
server1 = Placeholder()
server2 = Placeholder()
flow = Placeholder()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
<< Hook("http_connect", Placeholder())
>> reply()
<< OpenConnection(server1)
>> reply(None)
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< Hook("next_layer", Placeholder())
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
<< Hook("request", flow)
>> ConnectionClosed(server1)
>> reply(to=-2)
<< OpenConnection(server2)
>> reply(None)
<< SendData(server2, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
>> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
)
assert server1() != server2()
assert flow().server_conn == server2()

View File

@ -1,7 +1,7 @@
from mitmproxy.proxy2.commands import CloseConnection, Hook, OpenConnection, SendData
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import TCPLayer
from ..tutils import Placeholder, playbook, reply
from ..tutils import Placeholder, Playbook, reply
def test_open_connection(tctx):
@ -10,13 +10,13 @@ def test_open_connection(tctx):
because the server may send data first.
"""
assert (
playbook(TCPLayer(tctx, True))
Playbook(TCPLayer(tctx, True))
<< OpenConnection(tctx.server)
)
tctx.server.connected = True
assert (
playbook(TCPLayer(tctx, True))
Playbook(TCPLayer(tctx, True))
<< None
)
@ -24,7 +24,7 @@ def test_open_connection(tctx):
def test_open_connection_err(tctx):
f = Placeholder()
assert (
playbook(TCPLayer(tctx))
Playbook(TCPLayer(tctx))
<< Hook("tcp_start", f)
>> reply()
<< OpenConnection(tctx.server)
@ -40,7 +40,7 @@ def test_simple(tctx):
f = Placeholder()
assert (
playbook(TCPLayer(tctx))
Playbook(TCPLayer(tctx))
<< Hook("tcp_start", f)
>> reply()
<< OpenConnection(tctx.server)
@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx):
will still be forwarded.
"""
assert (
playbook(TCPLayer(tctx), hooks=False)
Playbook(TCPLayer(tctx), hooks=False)
<< OpenConnection(tctx.server)
>> DataReceived(tctx.client, b"hello!")
>> reply(None, to=-2)
@ -84,7 +84,7 @@ def test_receive_data_after_half_close(tctx):
data received after the other connection has been half-closed should still be forwarded.
"""
assert (
playbook(TCPLayer(tctx), hooks=False)
Playbook(TCPLayer(tctx), hooks=False)
<< OpenConnection(tctx.server)
>> reply(None)
>> ConnectionClosed(tctx.server)

View File

@ -87,7 +87,7 @@ class SSLTest:
)
def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: context.Connection) -> None:
tssl.obj.write(b"Hello World")
data = tutils.Placeholder()
assert (
@ -110,7 +110,7 @@ class TlsEchoLayer(tutils.EchoLayer):
yield from super()._handle_event(event)
def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest):
def interact(playbook: tutils.Playbook, conn: context.Connection, tssl: SSLTest):
data = tutils.Placeholder()
assert (
playbook
@ -149,7 +149,7 @@ class TestServerTLS:
# Handshake
assert (
tutils.playbook(layer)
tutils.Playbook(layer)
>> events.DataReceived(tctx.client, b"Hello World")
<< commands.SendData(tctx.client, b"hello world")
>> events.DataReceived(tctx.server, b"Foo")
@ -158,7 +158,7 @@ class TestServerTLS:
def test_simple(self, tctx):
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
playbook = tutils.Playbook(layer)
tctx.server.connected = True
tctx.server.address = ("example.com", 443)
@ -170,7 +170,7 @@ class TestServerTLS:
playbook
>> events.DataReceived(tctx.client, b"establish-server-tls")
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(TlsEchoLayer)
>> tutils.reply_next_layer(TlsEchoLayer)
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.server, data)
@ -196,21 +196,21 @@ class TestServerTLS:
_test_echo(playbook, tssl, tctx.server)
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]:
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.Playbook, tls.ClientTLSLayer]:
# This is a bit contrived as the client layer expects a server layer as parent.
# We also set child layers manually to avoid NextLayer noise.
server_layer = tls.ServerTLSLayer(tctx)
client_layer = tls.ClientTLSLayer(tctx)
server_layer.child_layer = client_layer
client_layer.child_layer = TlsEchoLayer(tctx)
playbook = tutils.playbook(server_layer)
playbook = tutils.Playbook(server_layer)
return playbook, client_layer
def _test_tls_client_server(
tctx: context.Context,
sni: typing.Optional[bytes]
) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]:
) -> typing.Tuple[tutils.Playbook, tls.ClientTLSLayer, SSLTest]:
playbook, client_layer = _make_client_tls_layer(tctx)
tctx.server.tls = True
tctx.server.address = ("example.com", 443)

View File

@ -5,7 +5,7 @@ from test.mitmproxy.proxy2 import tutils
class TestNextLayer:
def test_simple(self, tctx):
nl = layer.NextLayer(tctx)
playbook = tutils.playbook(nl, hooks=True)
playbook = tutils.Playbook(nl, hooks=True)
assert (
playbook
@ -32,7 +32,7 @@ class TestNextLayer:
a reply from the proxy core.
"""
nl = layer.NextLayer(tctx)
playbook = tutils.playbook(nl)
playbook = tutils.Playbook(nl)
assert (
playbook
@ -52,7 +52,7 @@ class TestNextLayer:
def test_func_references(self, tctx):
nl = layer.NextLayer(tctx)
playbook = tutils.playbook(nl)
playbook = tutils.Playbook(nl)
assert (
playbook

View File

@ -38,7 +38,7 @@ class TLayer(Layer):
@pytest.fixture
def tplaybook(tctx):
return tutils.playbook(TLayer(tctx), expected=[])
return tutils.Playbook(TLayer(tctx), expected=[])
def test_simple(tplaybook):
@ -164,7 +164,7 @@ def test_command_reply(tplaybook):
def test_default_playbook(tctx):
p = tutils.playbook(TLayer(tctx))
p = tutils.Playbook(TLayer(tctx))
assert p
assert len(p.actual) == 1
assert isinstance(p.actual[0], events.Start)

View File

@ -2,7 +2,8 @@ import collections.abc
import copy
import difflib
import itertools
import sys
import re
import traceback
import typing
from mitmproxy.proxy2 import commands, context
@ -10,14 +11,15 @@ from mitmproxy.proxy2 import events
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.events import command_reply_subclasses
from mitmproxy.proxy2.layer import Layer, NextLayer
from mitmproxy.proxy2.layers import tls
TPlaybookEntry = typing.Union[commands.Command, events.Event]
TPlaybook = typing.List[TPlaybookEntry]
PlaybookEntry = typing.Union[commands.Command, events.Event]
PlaybookEntryList = typing.List[PlaybookEntry]
def _eq(
a: TPlaybookEntry,
b: TPlaybookEntry
a: PlaybookEntry,
b: PlaybookEntry
) -> bool:
"""Compare two commands/events, and possibly update placeholders."""
if type(a) != type(b):
@ -45,8 +47,8 @@ def _eq(
def eq(
a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]],
b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]]
a: typing.Union[PlaybookEntry, typing.Iterable[PlaybookEntry]],
b: typing.Union[PlaybookEntry, typing.Iterable[PlaybookEntry]]
):
"""
Compare an indiviual event/command or a list of events/commands.
@ -58,15 +60,39 @@ def eq(
return _eq(a, b)
def _fmt_entry(x: TPlaybookEntry):
def _fmt_entry(x: PlaybookEntry):
arrow = ">>" if isinstance(x, events.Event) else "<<"
x = str(x) \
.replace('Placeholder:None', '<unset placeholder>') \
.replace('Placeholder:', '')
x = str(x)
x = re.sub('Placeholder:None', '<unset placeholder>', x, flags=re.IGNORECASE)
x = re.sub('Placeholder:', '', x, flags=re.IGNORECASE)
return f"{arrow} {x}"
class playbook:
def _merge_sends(lst: PlaybookEntryList) -> PlaybookEntryList:
merged = lst[:1]
for x in lst[1:]:
prev = merged[-1]
two_subsequent_sends_to_the_same_remote = (
isinstance(x, commands.SendData) and
isinstance(prev, commands.SendData) and
x.connection is prev.connection
)
if two_subsequent_sends_to_the_same_remote:
prev.data += x.data
else:
merged.append(x)
return merged
class _TracebackInPlaybook(commands.Command):
def __init__(self, exc):
self.e = exc
def __repr__(self):
return self.e
class Playbook:
"""
Assert that a layer emits the expected commands in reaction to a given sequence of events.
For example, the following code asserts that the TCP layer emits an OpenConnection command
@ -88,9 +114,9 @@ class playbook:
"""
layer: Layer
"""The base layer"""
expected: TPlaybook
expected: PlaybookEntryList
"""expected command/event sequence"""
actual: TPlaybook
actual: PlaybookEntryList
"""actual command/event sequence"""
_errored: bool
"""used to check if playbook as been fully asserted"""
@ -98,13 +124,15 @@ class playbook:
"""If False, the playbook specification doesn't contain log commands."""
hooks: bool
"""If False, the playbook specification doesn't include hooks or hook replies. They are automatically replied to."""
merge_sends: bool
"""If True, subsequent SendData commands to the same remote will be merged in both expected and actual playbook."""
def __init__(
self,
layer: Layer,
hooks: bool = True,
logs: bool = False,
expected: typing.Optional[TPlaybook] = None,
expected: typing.Optional[PlaybookEntryList] = None,
):
if expected is None:
expected = [
@ -121,8 +149,6 @@ class playbook:
def __rshift__(self, e):
"""Add an event to send"""
assert isinstance(e, events.Event)
if not self.hooks and isinstance(e, events.HookReply):
raise ValueError(f"Playbook must not contain hook replies if hooks=False: {e}")
self.expected.append(e)
return self
@ -131,10 +157,6 @@ class playbook:
if c is None:
return self
assert isinstance(c, commands.Command)
if not self.logs and isinstance(c, commands.Log):
raise ValueError(f"Playbook must not contain log commands if logs=False: {c}")
if not self.hooks and isinstance(c, commands.Hook):
raise ValueError(f"Playbook must not contain hook commands if hooks=False: {c}")
self.expected.append(c)
return self
@ -149,25 +171,41 @@ class playbook:
else:
if hasattr(x, "playbook_eval"):
x = self.expected[i] = x.playbook_eval(self)
for name, value in vars(x).items():
if isinstance(value, _Placeholder):
setattr(x, name, value())
if isinstance(x, events.OpenConnectionReply) and not x.reply:
x.command.connection.state = ConnectionState.OPEN
elif isinstance(x, events.ConnectionClosed):
x.connection.state &= ~ConnectionState.CAN_READ
self.actual.append(x)
try:
cmds = list(self.layer.handle_event(x))
except Exception:
self.actual.append(_TracebackInPlaybook(traceback.format_exc()))
break
self.actual.extend(cmds)
if not self.logs:
for offset, cmd in enumerate(cmds):
if isinstance(cmd, commands.Log):
self.expected.insert(i + 1 + offset, cmd)
pos = i + 1 + offset
if isinstance(cmd, commands.Log) and not isinstance(self.expected[pos], commands.Log):
self.expected.insert(pos, cmd)
if not self.hooks:
last_cmd = self.actual[-1]
if isinstance(last_cmd, commands.Hook):
self.expected.insert(i + len(cmds), last_cmd)
self.expected.insert(i + len(cmds) + 1, events.HookReply(last_cmd))
pos = i + len(cmds)
need_to_emulate_hook = (
isinstance(last_cmd, commands.Hook) and
not (isinstance(self.expected[pos], commands.Hook) and self.expected[pos].name == last_cmd.name)
)
if need_to_emulate_hook:
self.expected.insert(pos, last_cmd)
self.expected.insert(pos + 1, events.HookReply(last_cmd))
i += 1
self.actual = _merge_sends(self.actual)
self.expected = _merge_sends(self.expected)
if not eq(self.expected, self.actual):
self._errored = True
diff = "\n".join(difflib.ndiff(
@ -210,7 +248,7 @@ class reply(events.Event):
self.to = to
self.side_effect = side_effect
def playbook_eval(self, playbook: playbook) -> events.CommandReply:
def playbook_eval(self, playbook: Playbook) -> events.CommandReply:
if isinstance(self.to, int):
expected = playbook.expected[:playbook.expected.index(self)]
assert abs(self.to) < len(expected)
@ -225,7 +263,7 @@ class reply(events.Event):
break
else:
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
raise AssertionError(f"Expected command {self.to} did not occur:\n{actual_str}")
assert isinstance(self.to, commands.Command)
self.side_effect(self.to)
@ -279,23 +317,12 @@ class EchoLayer(Layer):
yield commands.SendData(event.connection, event.data.lower())
def next_layer(
def reply_next_layer(
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]],
*args,
**kwargs
) -> reply:
"""
Helper function to simplify the syntax for next_layer events from this:
<< commands.Hook("next_layer", next_layer)
)
next_layer().layer = tutils.EchoLayer(next_layer().context)
assert (
playbook
>> events.HookReply(-1)
to this:
"""Helper function to simplify the syntax for next_layer events to this:
<< commands.Hook("next_layer", next_layer)
>> tutils.next_layer(next_layer, tutils.EchoLayer)
"""
@ -305,3 +332,15 @@ def next_layer(
hook.data.layer = layer(hook.data.context)
return reply(*args, side_effect=set_layer, **kwargs)
def reply_establish_server_tls(**kwargs) -> reply:
"""Helper function to simplify the syntax for EstablishServerTls events to this:
<< tls.EstablishServerTLS(server)
>> tutils.reply_establish_server_tls()
"""
def fake_tls(cmd: tls.EstablishServerTLS) -> None:
cmd.connection.tls_established = True
return reply(None, side_effect=fake_tls, **kwargs)