mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] http tests++
This commit is contained in:
parent
1c80dfe17f
commit
6ee7802bf1
@ -53,7 +53,7 @@ class SendData(ConnectionCommand):
|
||||
self.data = data
|
||||
|
||||
def __repr__(self):
|
||||
target = type(self.connection).__name__.lower()
|
||||
target = str(self.connection).split("(", 1)[0].lower()
|
||||
return f"SendData({target}, {self.data})"
|
||||
|
||||
|
||||
|
@ -66,7 +66,7 @@ class CommandReply(Event):
|
||||
command: commands.Command
|
||||
reply: typing.Any
|
||||
|
||||
def __init__(self, command: typing.Union[commands.Command, int], reply: typing.Any):
|
||||
def __init__(self, command: commands.Command, reply: typing.Any):
|
||||
self.command = command
|
||||
self.reply = reply
|
||||
|
||||
@ -91,7 +91,7 @@ class OpenConnectionReply(CommandReply):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: typing.Union[commands.OpenConnection, int],
|
||||
command: commands.OpenConnection,
|
||||
err: typing.Optional[str]
|
||||
):
|
||||
super().__init__(command, err)
|
||||
@ -113,7 +113,7 @@ class GetSocketReply(CommandReply):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: typing.Union[commands.GetSocket, int],
|
||||
command: commands.GetSocket,
|
||||
socket: socket.socket
|
||||
):
|
||||
super().__init__(command, socket)
|
||||
|
@ -497,6 +497,8 @@ class HttpStream(Layer):
|
||||
self.flow.error = flow.Error(err)
|
||||
yield commands.Hook("error", self.flow)
|
||||
return
|
||||
else:
|
||||
self.flow.server_conn = connection
|
||||
|
||||
yield SendHttp(RequestHeaders(self.flow.request, self.stream_id), connection)
|
||||
|
||||
@ -660,6 +662,7 @@ class HTTPLayer(Layer):
|
||||
can_reuse_context_connection = (
|
||||
self.context.server not in self.connections and
|
||||
self.context.server.connected and
|
||||
self.context.server.address == event.address and
|
||||
self.context.server.tls == event.tls
|
||||
)
|
||||
if can_reuse_context_connection:
|
||||
|
@ -1,12 +1,15 @@
|
||||
import pytest
|
||||
|
||||
from mitmproxy import options
|
||||
from mitmproxy.addons.proxyserver import Proxyserver
|
||||
from mitmproxy.proxy2 import context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tctx():
|
||||
opts = options.Options()
|
||||
Proxyserver().load(opts)
|
||||
return context.Context(
|
||||
context.Client(("client", 1234)),
|
||||
options.Options()
|
||||
opts
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ from .. import tutils
|
||||
@pytest.fixture
|
||||
def ws_playbook(tctx):
|
||||
tctx.server.connected = True
|
||||
playbook = tutils.playbook(
|
||||
playbook = tutils.Playbook(
|
||||
websocket.WebsocketLayer(
|
||||
tctx,
|
||||
tflow.twebsocketflow().handshake_flow
|
@ -1,29 +1,200 @@
|
||||
def test_http_proxy():
|
||||
import pytest
|
||||
|
||||
from mitmproxy.http import HTTPResponse
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2.commands import Hook, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.proxy2.layers.http import http
|
||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
|
||||
|
||||
|
||||
def test_http_proxy(tctx):
|
||||
"""Test a simple HTTP GET / request"""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< Hook("requestheaders", flow)
|
||||
>> reply()
|
||||
<< Hook("request", flow)
|
||||
>> reply()
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
<< SendData(server, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World")
|
||||
<< Hook("responseheaders", flow)
|
||||
>> reply()
|
||||
>> DataReceived(server, b"!")
|
||||
<< Hook("response", flow)
|
||||
>> reply()
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
)
|
||||
assert server().address == ("example.com", 80)
|
||||
|
||||
|
||||
def test_https_proxy_eager():
|
||||
"""Test a CONNECT request, followed by TLS, followed by a HTTP GET /"""
|
||||
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
|
||||
def test_https_proxy(strategy, tctx):
|
||||
"""Test a CONNECT request, followed by a HTTP GET /"""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
|
||||
tctx.options.connection_strategy = strategy
|
||||
|
||||
(playbook
|
||||
>> DataReceived(tctx.client, b"CONNECT example.proxy:80 HTTP/1.1\r\n\r\n")
|
||||
<< Hook("http_connect", Placeholder())
|
||||
>> reply())
|
||||
if strategy == "eager":
|
||||
(playbook
|
||||
<< OpenConnection(server)
|
||||
>> reply(None))
|
||||
(playbook
|
||||
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< Hook("next_layer", Placeholder())
|
||||
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
||||
<< Hook("requestheaders", flow)
|
||||
>> reply()
|
||||
<< Hook("request", flow)
|
||||
>> reply())
|
||||
if strategy == "lazy":
|
||||
(playbook
|
||||
<< OpenConnection(server)
|
||||
>> reply(None))
|
||||
(playbook
|
||||
<< SendData(server, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
<< Hook("responseheaders", flow)
|
||||
>> reply()
|
||||
<< Hook("response", flow)
|
||||
>> reply()
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!"))
|
||||
assert playbook
|
||||
|
||||
|
||||
def test_https_proxy_lazy():
|
||||
"""Test a CONNECT request, followed by TLS, followed by a HTTP GET /"""
|
||||
@pytest.mark.parametrize("https_client", [False, True])
|
||||
@pytest.mark.parametrize("https_server", [False, True])
|
||||
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
|
||||
def test_redirect(strategy, https_server, https_client, tctx):
|
||||
"""Test redirects between http:// and https:// in regular proxy mode."""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
tctx.options.connection_strategy = strategy
|
||||
p = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
|
||||
def redirect(hook: Hook):
|
||||
if https_server:
|
||||
hook.data.request.url = "https://redirected.site/"
|
||||
else:
|
||||
hook.data.request.url = "http://redirected.site/"
|
||||
|
||||
if https_client:
|
||||
p >> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
|
||||
if strategy == "eager":
|
||||
p << OpenConnection(Placeholder())
|
||||
p >> reply(None)
|
||||
p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
p << Hook("next_layer", Placeholder())
|
||||
p >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
||||
else:
|
||||
p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
p << Hook("request", flow)
|
||||
p >> reply(side_effect=redirect)
|
||||
p << OpenConnection(server)
|
||||
p >> reply(None)
|
||||
if https_server:
|
||||
p << tls.EstablishServerTLS(server)
|
||||
p >> reply_establish_server_tls()
|
||||
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
|
||||
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
|
||||
assert p
|
||||
if https_server:
|
||||
assert server().address == ("redirected.site", 443)
|
||||
else:
|
||||
assert server().address == ("redirected.site", 80)
|
||||
|
||||
|
||||
def test_http_to_https():
|
||||
"""Test a simple HTTP GET request that is being rewritten to HTTPS by an addon."""
|
||||
|
||||
|
||||
def test_http_redirect():
|
||||
"""Test a simple HTTP GET request that redirected to another host"""
|
||||
|
||||
|
||||
def test_multiple_server_connections():
|
||||
def test_multiple_server_connections(tctx):
|
||||
"""Test multiple requests being rewritten to different targets."""
|
||||
server1 = Placeholder()
|
||||
server2 = Placeholder()
|
||||
|
||||
def redirect(to: str):
|
||||
def side_effect(hook: Hook):
|
||||
hook.data.request.url = to
|
||||
|
||||
return side_effect
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< Hook("request", Placeholder())
|
||||
>> reply(side_effect=redirect("http://one.redirect/"))
|
||||
<< OpenConnection(server1)
|
||||
>> reply(None)
|
||||
<< SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n")
|
||||
>> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< Hook("request", Placeholder())
|
||||
>> reply(side_effect=redirect("http://two.redirect/"))
|
||||
<< OpenConnection(server2)
|
||||
>> reply(None)
|
||||
<< SendData(server2, b"GET / HTTP/1.1\r\nHost: two.redirect\r\n\r\n")
|
||||
>> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
)
|
||||
assert server1().address == ("one.redirect", 80)
|
||||
assert server2().address == ("two.redirect", 80)
|
||||
|
||||
|
||||
def test_http_reply_from_proxy():
|
||||
def test_http_reply_from_proxy(tctx):
|
||||
"""Test a response served by mitmproxy itself."""
|
||||
|
||||
def test_disconnect_while_intercept():
|
||||
def reply_from_proxy(hook: Hook):
|
||||
hook.data.response = HTTPResponse.make(418)
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< Hook("request", Placeholder())
|
||||
>> reply(side_effect=reply_from_proxy)
|
||||
<< SendData(tctx.client, b"HTTP/1.1 418 I'm a teapot\r\ncontent-length: 0\r\n\r\n")
|
||||
)
|
||||
|
||||
|
||||
def test_disconnect_while_intercept(tctx):
|
||||
"""Test a server disconnect while a request is intercepted."""
|
||||
tctx.options.connection_strategy = "eager"
|
||||
|
||||
server1 = Placeholder()
|
||||
server2 = Placeholder()
|
||||
flow = Placeholder()
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
|
||||
<< Hook("http_connect", Placeholder())
|
||||
>> reply()
|
||||
<< OpenConnection(server1)
|
||||
>> reply(None)
|
||||
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< Hook("next_layer", Placeholder())
|
||||
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
||||
<< Hook("request", flow)
|
||||
>> ConnectionClosed(server1)
|
||||
>> reply(to=-2)
|
||||
<< OpenConnection(server2)
|
||||
>> reply(None)
|
||||
<< SendData(server2, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
>> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
)
|
||||
assert server1() != server2()
|
||||
assert flow().server_conn == server2()
|
||||
|
@ -1,7 +1,7 @@
|
||||
from mitmproxy.proxy2.commands import CloseConnection, Hook, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import TCPLayer
|
||||
from ..tutils import Placeholder, playbook, reply
|
||||
from ..tutils import Placeholder, Playbook, reply
|
||||
|
||||
|
||||
def test_open_connection(tctx):
|
||||
@ -10,13 +10,13 @@ def test_open_connection(tctx):
|
||||
because the server may send data first.
|
||||
"""
|
||||
assert (
|
||||
playbook(TCPLayer(tctx, True))
|
||||
Playbook(TCPLayer(tctx, True))
|
||||
<< OpenConnection(tctx.server)
|
||||
)
|
||||
|
||||
tctx.server.connected = True
|
||||
assert (
|
||||
playbook(TCPLayer(tctx, True))
|
||||
Playbook(TCPLayer(tctx, True))
|
||||
<< None
|
||||
)
|
||||
|
||||
@ -24,7 +24,7 @@ def test_open_connection(tctx):
|
||||
def test_open_connection_err(tctx):
|
||||
f = Placeholder()
|
||||
assert (
|
||||
playbook(TCPLayer(tctx))
|
||||
Playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
@ -40,7 +40,7 @@ def test_simple(tctx):
|
||||
f = Placeholder()
|
||||
|
||||
assert (
|
||||
playbook(TCPLayer(tctx))
|
||||
Playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx):
|
||||
will still be forwarded.
|
||||
"""
|
||||
assert (
|
||||
playbook(TCPLayer(tctx), hooks=False)
|
||||
Playbook(TCPLayer(tctx), hooks=False)
|
||||
<< OpenConnection(tctx.server)
|
||||
>> DataReceived(tctx.client, b"hello!")
|
||||
>> reply(None, to=-2)
|
||||
@ -84,7 +84,7 @@ def test_receive_data_after_half_close(tctx):
|
||||
data received after the other connection has been half-closed should still be forwarded.
|
||||
"""
|
||||
assert (
|
||||
playbook(TCPLayer(tctx), hooks=False)
|
||||
Playbook(TCPLayer(tctx), hooks=False)
|
||||
<< OpenConnection(tctx.server)
|
||||
>> reply(None)
|
||||
>> ConnectionClosed(tctx.server)
|
||||
|
@ -87,7 +87,7 @@ class SSLTest:
|
||||
)
|
||||
|
||||
|
||||
def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
|
||||
def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: context.Connection) -> None:
|
||||
tssl.obj.write(b"Hello World")
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
@ -110,7 +110,7 @@ class TlsEchoLayer(tutils.EchoLayer):
|
||||
yield from super()._handle_event(event)
|
||||
|
||||
|
||||
def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest):
|
||||
def interact(playbook: tutils.Playbook, conn: context.Connection, tssl: SSLTest):
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
@ -149,7 +149,7 @@ class TestServerTLS:
|
||||
|
||||
# Handshake
|
||||
assert (
|
||||
tutils.playbook(layer)
|
||||
tutils.Playbook(layer)
|
||||
>> events.DataReceived(tctx.client, b"Hello World")
|
||||
<< commands.SendData(tctx.client, b"hello world")
|
||||
>> events.DataReceived(tctx.server, b"Foo")
|
||||
@ -158,7 +158,7 @@ class TestServerTLS:
|
||||
|
||||
def test_simple(self, tctx):
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
playbook = tutils.Playbook(layer)
|
||||
tctx.server.connected = True
|
||||
tctx.server.address = ("example.com", 443)
|
||||
|
||||
@ -170,7 +170,7 @@ class TestServerTLS:
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"establish-server-tls")
|
||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
||||
>> tutils.next_layer(TlsEchoLayer)
|
||||
>> tutils.reply_next_layer(TlsEchoLayer)
|
||||
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.server, data)
|
||||
@ -196,21 +196,21 @@ class TestServerTLS:
|
||||
_test_echo(playbook, tssl, tctx.server)
|
||||
|
||||
|
||||
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]:
|
||||
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.Playbook, tls.ClientTLSLayer]:
|
||||
# This is a bit contrived as the client layer expects a server layer as parent.
|
||||
# We also set child layers manually to avoid NextLayer noise.
|
||||
server_layer = tls.ServerTLSLayer(tctx)
|
||||
client_layer = tls.ClientTLSLayer(tctx)
|
||||
server_layer.child_layer = client_layer
|
||||
client_layer.child_layer = TlsEchoLayer(tctx)
|
||||
playbook = tutils.playbook(server_layer)
|
||||
playbook = tutils.Playbook(server_layer)
|
||||
return playbook, client_layer
|
||||
|
||||
|
||||
def _test_tls_client_server(
|
||||
tctx: context.Context,
|
||||
sni: typing.Optional[bytes]
|
||||
) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]:
|
||||
) -> typing.Tuple[tutils.Playbook, tls.ClientTLSLayer, SSLTest]:
|
||||
playbook, client_layer = _make_client_tls_layer(tctx)
|
||||
tctx.server.tls = True
|
||||
tctx.server.address = ("example.com", 443)
|
||||
|
@ -5,7 +5,7 @@ from test.mitmproxy.proxy2 import tutils
|
||||
class TestNextLayer:
|
||||
def test_simple(self, tctx):
|
||||
nl = layer.NextLayer(tctx)
|
||||
playbook = tutils.playbook(nl, hooks=True)
|
||||
playbook = tutils.Playbook(nl, hooks=True)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
@ -32,7 +32,7 @@ class TestNextLayer:
|
||||
a reply from the proxy core.
|
||||
"""
|
||||
nl = layer.NextLayer(tctx)
|
||||
playbook = tutils.playbook(nl)
|
||||
playbook = tutils.Playbook(nl)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
@ -52,7 +52,7 @@ class TestNextLayer:
|
||||
|
||||
def test_func_references(self, tctx):
|
||||
nl = layer.NextLayer(tctx)
|
||||
playbook = tutils.playbook(nl)
|
||||
playbook = tutils.Playbook(nl)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
|
@ -38,7 +38,7 @@ class TLayer(Layer):
|
||||
|
||||
@pytest.fixture
|
||||
def tplaybook(tctx):
|
||||
return tutils.playbook(TLayer(tctx), expected=[])
|
||||
return tutils.Playbook(TLayer(tctx), expected=[])
|
||||
|
||||
|
||||
def test_simple(tplaybook):
|
||||
@ -164,7 +164,7 @@ def test_command_reply(tplaybook):
|
||||
|
||||
|
||||
def test_default_playbook(tctx):
|
||||
p = tutils.playbook(TLayer(tctx))
|
||||
p = tutils.Playbook(TLayer(tctx))
|
||||
assert p
|
||||
assert len(p.actual) == 1
|
||||
assert isinstance(p.actual[0], events.Start)
|
||||
|
@ -2,7 +2,8 @@ import collections.abc
|
||||
import copy
|
||||
import difflib
|
||||
import itertools
|
||||
import sys
|
||||
import re
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from mitmproxy.proxy2 import commands, context
|
||||
@ -10,14 +11,15 @@ from mitmproxy.proxy2 import events
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.proxy2.events import command_reply_subclasses
|
||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
|
||||
TPlaybookEntry = typing.Union[commands.Command, events.Event]
|
||||
TPlaybook = typing.List[TPlaybookEntry]
|
||||
PlaybookEntry = typing.Union[commands.Command, events.Event]
|
||||
PlaybookEntryList = typing.List[PlaybookEntry]
|
||||
|
||||
|
||||
def _eq(
|
||||
a: TPlaybookEntry,
|
||||
b: TPlaybookEntry
|
||||
a: PlaybookEntry,
|
||||
b: PlaybookEntry
|
||||
) -> bool:
|
||||
"""Compare two commands/events, and possibly update placeholders."""
|
||||
if type(a) != type(b):
|
||||
@ -45,8 +47,8 @@ def _eq(
|
||||
|
||||
|
||||
def eq(
|
||||
a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]],
|
||||
b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]]
|
||||
a: typing.Union[PlaybookEntry, typing.Iterable[PlaybookEntry]],
|
||||
b: typing.Union[PlaybookEntry, typing.Iterable[PlaybookEntry]]
|
||||
):
|
||||
"""
|
||||
Compare an indiviual event/command or a list of events/commands.
|
||||
@ -58,15 +60,39 @@ def eq(
|
||||
return _eq(a, b)
|
||||
|
||||
|
||||
def _fmt_entry(x: TPlaybookEntry):
|
||||
def _fmt_entry(x: PlaybookEntry):
|
||||
arrow = ">>" if isinstance(x, events.Event) else "<<"
|
||||
x = str(x) \
|
||||
.replace('Placeholder:None', '<unset placeholder>') \
|
||||
.replace('Placeholder:', '')
|
||||
x = str(x)
|
||||
x = re.sub('Placeholder:None', '<unset placeholder>', x, flags=re.IGNORECASE)
|
||||
x = re.sub('Placeholder:', '', x, flags=re.IGNORECASE)
|
||||
return f"{arrow} {x}"
|
||||
|
||||
|
||||
class playbook:
|
||||
def _merge_sends(lst: PlaybookEntryList) -> PlaybookEntryList:
|
||||
merged = lst[:1]
|
||||
for x in lst[1:]:
|
||||
prev = merged[-1]
|
||||
two_subsequent_sends_to_the_same_remote = (
|
||||
isinstance(x, commands.SendData) and
|
||||
isinstance(prev, commands.SendData) and
|
||||
x.connection is prev.connection
|
||||
)
|
||||
if two_subsequent_sends_to_the_same_remote:
|
||||
prev.data += x.data
|
||||
else:
|
||||
merged.append(x)
|
||||
return merged
|
||||
|
||||
|
||||
class _TracebackInPlaybook(commands.Command):
|
||||
def __init__(self, exc):
|
||||
self.e = exc
|
||||
|
||||
def __repr__(self):
|
||||
return self.e
|
||||
|
||||
|
||||
class Playbook:
|
||||
"""
|
||||
Assert that a layer emits the expected commands in reaction to a given sequence of events.
|
||||
For example, the following code asserts that the TCP layer emits an OpenConnection command
|
||||
@ -88,9 +114,9 @@ class playbook:
|
||||
"""
|
||||
layer: Layer
|
||||
"""The base layer"""
|
||||
expected: TPlaybook
|
||||
expected: PlaybookEntryList
|
||||
"""expected command/event sequence"""
|
||||
actual: TPlaybook
|
||||
actual: PlaybookEntryList
|
||||
"""actual command/event sequence"""
|
||||
_errored: bool
|
||||
"""used to check if playbook as been fully asserted"""
|
||||
@ -98,13 +124,15 @@ class playbook:
|
||||
"""If False, the playbook specification doesn't contain log commands."""
|
||||
hooks: bool
|
||||
"""If False, the playbook specification doesn't include hooks or hook replies. They are automatically replied to."""
|
||||
merge_sends: bool
|
||||
"""If True, subsequent SendData commands to the same remote will be merged in both expected and actual playbook."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Layer,
|
||||
hooks: bool = True,
|
||||
logs: bool = False,
|
||||
expected: typing.Optional[TPlaybook] = None,
|
||||
expected: typing.Optional[PlaybookEntryList] = None,
|
||||
):
|
||||
if expected is None:
|
||||
expected = [
|
||||
@ -121,8 +149,6 @@ class playbook:
|
||||
def __rshift__(self, e):
|
||||
"""Add an event to send"""
|
||||
assert isinstance(e, events.Event)
|
||||
if not self.hooks and isinstance(e, events.HookReply):
|
||||
raise ValueError(f"Playbook must not contain hook replies if hooks=False: {e}")
|
||||
self.expected.append(e)
|
||||
return self
|
||||
|
||||
@ -131,10 +157,6 @@ class playbook:
|
||||
if c is None:
|
||||
return self
|
||||
assert isinstance(c, commands.Command)
|
||||
if not self.logs and isinstance(c, commands.Log):
|
||||
raise ValueError(f"Playbook must not contain log commands if logs=False: {c}")
|
||||
if not self.hooks and isinstance(c, commands.Hook):
|
||||
raise ValueError(f"Playbook must not contain hook commands if hooks=False: {c}")
|
||||
self.expected.append(c)
|
||||
return self
|
||||
|
||||
@ -149,25 +171,41 @@ class playbook:
|
||||
else:
|
||||
if hasattr(x, "playbook_eval"):
|
||||
x = self.expected[i] = x.playbook_eval(self)
|
||||
for name, value in vars(x).items():
|
||||
if isinstance(value, _Placeholder):
|
||||
setattr(x, name, value())
|
||||
if isinstance(x, events.OpenConnectionReply) and not x.reply:
|
||||
x.command.connection.state = ConnectionState.OPEN
|
||||
elif isinstance(x, events.ConnectionClosed):
|
||||
x.connection.state &= ~ConnectionState.CAN_READ
|
||||
|
||||
self.actual.append(x)
|
||||
try:
|
||||
cmds = list(self.layer.handle_event(x))
|
||||
except Exception:
|
||||
self.actual.append(_TracebackInPlaybook(traceback.format_exc()))
|
||||
break
|
||||
self.actual.extend(cmds)
|
||||
if not self.logs:
|
||||
for offset, cmd in enumerate(cmds):
|
||||
if isinstance(cmd, commands.Log):
|
||||
self.expected.insert(i + 1 + offset, cmd)
|
||||
pos = i + 1 + offset
|
||||
if isinstance(cmd, commands.Log) and not isinstance(self.expected[pos], commands.Log):
|
||||
self.expected.insert(pos, cmd)
|
||||
if not self.hooks:
|
||||
last_cmd = self.actual[-1]
|
||||
if isinstance(last_cmd, commands.Hook):
|
||||
self.expected.insert(i + len(cmds), last_cmd)
|
||||
self.expected.insert(i + len(cmds) + 1, events.HookReply(last_cmd))
|
||||
pos = i + len(cmds)
|
||||
need_to_emulate_hook = (
|
||||
isinstance(last_cmd, commands.Hook) and
|
||||
not (isinstance(self.expected[pos], commands.Hook) and self.expected[pos].name == last_cmd.name)
|
||||
)
|
||||
if need_to_emulate_hook:
|
||||
self.expected.insert(pos, last_cmd)
|
||||
self.expected.insert(pos + 1, events.HookReply(last_cmd))
|
||||
i += 1
|
||||
|
||||
self.actual = _merge_sends(self.actual)
|
||||
self.expected = _merge_sends(self.expected)
|
||||
|
||||
if not eq(self.expected, self.actual):
|
||||
self._errored = True
|
||||
diff = "\n".join(difflib.ndiff(
|
||||
@ -210,7 +248,7 @@ class reply(events.Event):
|
||||
self.to = to
|
||||
self.side_effect = side_effect
|
||||
|
||||
def playbook_eval(self, playbook: playbook) -> events.CommandReply:
|
||||
def playbook_eval(self, playbook: Playbook) -> events.CommandReply:
|
||||
if isinstance(self.to, int):
|
||||
expected = playbook.expected[:playbook.expected.index(self)]
|
||||
assert abs(self.to) < len(expected)
|
||||
@ -225,7 +263,7 @@ class reply(events.Event):
|
||||
break
|
||||
else:
|
||||
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
|
||||
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
|
||||
raise AssertionError(f"Expected command {self.to} did not occur:\n{actual_str}")
|
||||
|
||||
assert isinstance(self.to, commands.Command)
|
||||
self.side_effect(self.to)
|
||||
@ -279,23 +317,12 @@ class EchoLayer(Layer):
|
||||
yield commands.SendData(event.connection, event.data.lower())
|
||||
|
||||
|
||||
def next_layer(
|
||||
def reply_next_layer(
|
||||
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]],
|
||||
*args,
|
||||
**kwargs
|
||||
) -> reply:
|
||||
"""
|
||||
Helper function to simplify the syntax for next_layer events from this:
|
||||
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
)
|
||||
next_layer().layer = tutils.EchoLayer(next_layer().context)
|
||||
assert (
|
||||
playbook
|
||||
>> events.HookReply(-1)
|
||||
|
||||
to this:
|
||||
|
||||
"""Helper function to simplify the syntax for next_layer events to this:
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
>> tutils.next_layer(next_layer, tutils.EchoLayer)
|
||||
"""
|
||||
@ -305,3 +332,15 @@ def next_layer(
|
||||
hook.data.layer = layer(hook.data.context)
|
||||
|
||||
return reply(*args, side_effect=set_layer, **kwargs)
|
||||
|
||||
|
||||
def reply_establish_server_tls(**kwargs) -> reply:
|
||||
"""Helper function to simplify the syntax for EstablishServerTls events to this:
|
||||
<< tls.EstablishServerTLS(server)
|
||||
>> tutils.reply_establish_server_tls()
|
||||
"""
|
||||
|
||||
def fake_tls(cmd: tls.EstablishServerTLS) -> None:
|
||||
cmd.connection.tls_established = True
|
||||
|
||||
return reply(None, side_effect=fake_tls, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user