mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] add transparent proxy, improve testing
This commit is contained in:
parent
a4803cfaae
commit
e0eb77a794
@ -90,6 +90,14 @@ class Hook(Command):
|
||||
# return f"Hook({self.name}: {data})"
|
||||
|
||||
|
||||
class GetSocket(ConnectionCommand):
|
||||
"""
|
||||
Get the underlying socket.
|
||||
This should really never be used, but is required to implement transparent mode.
|
||||
"""
|
||||
blocking = True
|
||||
|
||||
|
||||
class Log(Command):
|
||||
message: str
|
||||
level: str
|
||||
|
@ -27,6 +27,15 @@ class Connection:
|
||||
def connected(self):
|
||||
return self.state is ConnectionState.OPEN
|
||||
|
||||
@connected.setter
|
||||
def connected(self, val: bool) -> None:
|
||||
# We should really set .state, but verdict is still due if we even want to keep .state around.
|
||||
# We allow setting .connected while we figure that out.
|
||||
if val:
|
||||
self.state = ConnectionState.OPEN
|
||||
else:
|
||||
self.state = ConnectionState.CLOSED
|
||||
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}({repr(self.__dict__)})"
|
||||
|
||||
|
@ -3,6 +3,7 @@ When IO actions occur at the proxy server, they are passed down to layers as eve
|
||||
Events represent the only way for layers to receive new data from sockets.
|
||||
The counterpart to events are commands.
|
||||
"""
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from mitmproxy.proxy2 import commands
|
||||
@ -62,7 +63,7 @@ class CommandReply(Event):
|
||||
Emitted when a command has been finished, e.g.
|
||||
when the master has replied or when we have established a server connection.
|
||||
"""
|
||||
command: typing.Union[commands.Command, int]
|
||||
command: commands.Command
|
||||
reply: typing.Any
|
||||
|
||||
def __init__(self, command: typing.Union[commands.Command, int], reply: typing.Any):
|
||||
@ -74,10 +75,19 @@ class CommandReply(Event):
|
||||
raise TypeError("CommandReply may not be instantiated directly.")
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
command_cls = cls.__annotations__["command"]
|
||||
if not issubclass(command_cls, commands.Command) and command_cls is not commands.Command:
|
||||
raise RuntimeError(f"{command_cls} needs a properly annotated command attribute.")
|
||||
command_reply_subclasses[command_cls] = cls
|
||||
|
||||
|
||||
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
|
||||
|
||||
|
||||
class OpenConnectionReply(CommandReply):
|
||||
command: typing.Union[commands.OpenConnection, int]
|
||||
reply: str
|
||||
command: commands.OpenConnection
|
||||
reply: typing.Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -88,10 +98,22 @@ class OpenConnectionReply(CommandReply):
|
||||
|
||||
|
||||
class HookReply(CommandReply):
|
||||
command: typing.Union[commands.Hook, int]
|
||||
command: commands.Hook
|
||||
|
||||
def __init__(self, command: typing.Union[commands.Hook, int]):
|
||||
def __init__(self, command: commands.Hook):
|
||||
super().__init__(command, None)
|
||||
|
||||
def __repr__(self):
|
||||
return f"HookReply({repr(self.command)[5:-1]})"
|
||||
|
||||
|
||||
class GetSocketReply(CommandReply):
|
||||
command: commands.GetSocket
|
||||
reply: socket.socket
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: typing.Union[commands.GetSocket, int],
|
||||
socket: socket.socket
|
||||
):
|
||||
super().__init__(command, socket)
|
||||
|
@ -1,8 +1,6 @@
|
||||
from . import modes
|
||||
from .glue import GlueLayer
|
||||
from mitmproxy.proxy2.layers.old.old_http import OldHTTPLayer
|
||||
from .http.http import HTTPLayer
|
||||
from mitmproxy.proxy2.layers.old.http1 import ClientHTTP1Layer, ServerHTTP1Layer
|
||||
from .tcp import TCPLayer
|
||||
from .tls import ClientTLSLayer, ServerTLSLayer
|
||||
from .websocket import WebsocketLayer
|
||||
@ -10,10 +8,7 @@ from .websocket import WebsocketLayer
|
||||
__all__ = [
|
||||
"modes",
|
||||
"GlueLayer",
|
||||
"OldHTTPLayer", # TODO remove this and replace with ClientHTTP1Layer
|
||||
"HTTPLayer",
|
||||
"ClientHTTP1Layer", "ServerHTTP1Layer",
|
||||
"ClientHTTP2Layer", "ServerHTTP2Layer",
|
||||
"TCPLayer",
|
||||
"ClientTLSLayer", "ServerTLSLayer",
|
||||
"WebsocketLayer",
|
||||
|
@ -1,23 +1,41 @@
|
||||
from mitmproxy import platform
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.proxy2 import layer
|
||||
from mitmproxy.proxy2.context import Context, Server
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Server
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
|
||||
|
||||
class ReverseProxy(layer.Layer):
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context)
|
||||
spec = server_spec.parse_with_mode(context.options.mode)[1]
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
|
||||
self.context.server = Server(spec.address)
|
||||
if spec.scheme != "http":
|
||||
if spec.scheme not in ("http", "tcp"):
|
||||
self.context.server.tls = True
|
||||
if not context.options.keep_host_header:
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0]
|
||||
child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
yield from child_layer.handle_event(event)
|
||||
|
||||
|
||||
class HttpProxy(layer.Layer):
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context)
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
yield from child_layer.handle_event(event)
|
||||
|
||||
|
||||
class TransparentProxy(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
socket = yield commands.GetSocket(self.context.client)
|
||||
try:
|
||||
self.context.server.address = platform.original_addr(socket)
|
||||
except Exception as e:
|
||||
yield commands.Log(f"Transparent mode failure: {e!r}")
|
||||
|
||||
child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
yield from child_layer.handle_event(event)
|
||||
|
@ -1,4 +1,6 @@
|
||||
from mitmproxy import tcp, flow
|
||||
from typing import Optional
|
||||
|
||||
from mitmproxy import flow, tcp
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.context import Context
|
||||
from mitmproxy.proxy2.layer import Layer
|
||||
@ -10,26 +12,25 @@ class TCPLayer(Layer):
|
||||
Simple TCP layer that just relays messages right now.
|
||||
"""
|
||||
context: Context
|
||||
ignore: bool
|
||||
flow: tcp.TCPFlow
|
||||
flow: Optional[tcp.TCPFlow]
|
||||
|
||||
def __init__(self, context: Context, ignore: bool = False):
|
||||
super().__init__(context)
|
||||
self.ignore = ignore
|
||||
self.flow = None
|
||||
if ignore:
|
||||
self.flow = None
|
||||
else:
|
||||
self.flow = tcp.TCPFlow(self.context.client, self.context.server, True)
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> commands.TCommandGenerator:
|
||||
if not self.ignore:
|
||||
self.flow = tcp.TCPFlow(self.context.client, self.context.server, True)
|
||||
if self.flow:
|
||||
yield commands.Hook("tcp_start", self.flow)
|
||||
|
||||
if not self.context.server.connected:
|
||||
try:
|
||||
yield commands.OpenConnection(self.context.server)
|
||||
except IOError as e:
|
||||
if not self.ignore:
|
||||
self.flow.error = flow.Error(str(e))
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
if self.flow:
|
||||
self.flow.error = flow.Error(str(err))
|
||||
yield commands.Hook("tcp_error", self.flow)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
self._handle_event = self.done
|
||||
@ -47,19 +48,21 @@ class TCPLayer(Layer):
|
||||
send_to = self.context.client
|
||||
|
||||
if isinstance(event, events.DataReceived):
|
||||
if self.ignore:
|
||||
yield commands.SendData(send_to, event.data)
|
||||
else:
|
||||
if self.flow:
|
||||
tcp_message = tcp.TCPMessage(from_client, event.data)
|
||||
self.flow.messages.append(tcp_message)
|
||||
yield commands.Hook("tcp_message", self.flow)
|
||||
yield commands.SendData(send_to, tcp_message.content)
|
||||
else:
|
||||
yield commands.SendData(send_to, event.data)
|
||||
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
yield commands.CloseConnection(send_to)
|
||||
if not self.ignore:
|
||||
yield commands.Hook("tcp_end", self.flow)
|
||||
self._handle_event = self.done
|
||||
all_done = (not self.context.client.connected and not self.context.server.connected)
|
||||
if all_done:
|
||||
self._handle_event = self.done
|
||||
if self.flow:
|
||||
yield commands.Hook("tcp_end", self.flow)
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def done(self, _):
|
||||
|
@ -6,7 +6,7 @@ from OpenSSL import SSL
|
||||
|
||||
from mitmproxy.certs import CertStore
|
||||
from mitmproxy.net.tls import ClientHello
|
||||
from mitmproxy.proxy.protocol import tls
|
||||
from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2 import context
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
@ -362,7 +362,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
).get_cert(client.sni, (client.sni,))
|
||||
context.use_privatekey(privkey)
|
||||
context.use_certificate(cert.x509)
|
||||
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
|
||||
context.set_cipher_list(DEFAULT_CLIENT_CIPHERS)
|
||||
|
||||
def alpn_select_callback(conn_, options):
|
||||
if server.alpn in options:
|
||||
|
@ -31,7 +31,7 @@ class StreamIO(typing.NamedTuple):
|
||||
|
||||
class TimeoutWatchdog:
|
||||
last_activity: float
|
||||
CONNECTION_TIMEOUT = 120
|
||||
CONNECTION_TIMEOUT = 10 * 60
|
||||
can_timeout: asyncio.Event
|
||||
blocker: int
|
||||
|
||||
@ -197,6 +197,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
asyncio.ensure_future(
|
||||
self.shutdown_connection(command.connection)
|
||||
)
|
||||
elif isinstance(command, commands.GetSocket):
|
||||
socket = self.transports[command.connection].w.get_extra_info("socket")
|
||||
self.server_event(events.GetSocketReply(command, socket))
|
||||
elif isinstance(command, glue.GlueGetConnectionHandler):
|
||||
self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
|
||||
elif isinstance(command, commands.Hook):
|
||||
|
@ -18,7 +18,8 @@ def expect(*event_types):
|
||||
if isinstance(event, event_types):
|
||||
yield from f(self, event)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event type at {f}: Expected {event_types}, got {event}.")
|
||||
event_types_str = '|'.join(e.__name__ for e in event_types)
|
||||
raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.layers import tcp
|
||||
from .. import tutils
|
||||
from mitmproxy.proxy2.commands import CloseConnection, Hook, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import TCPLayer
|
||||
from ..tutils import Placeholder, playbook, reply
|
||||
|
||||
|
||||
def test_open_connection(tctx):
|
||||
@ -9,121 +10,101 @@ def test_open_connection(tctx):
|
||||
because the server may send data first.
|
||||
"""
|
||||
assert (
|
||||
tutils.playbook(tcp.TCPLayer(tctx, True))
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
playbook(TCPLayer(tctx, True))
|
||||
<< OpenConnection(tctx.server)
|
||||
)
|
||||
|
||||
tctx.server.connected = True
|
||||
assert (
|
||||
tutils.playbook(tcp.TCPLayer(tctx, True))
|
||||
<< None
|
||||
playbook(TCPLayer(tctx, True))
|
||||
<< None
|
||||
)
|
||||
|
||||
|
||||
def test_open_connection_err(tctx):
|
||||
f = tutils.Placeholder()
|
||||
f = Placeholder()
|
||||
assert (
|
||||
tutils.playbook(tcp.TCPLayer(tctx))
|
||||
<< commands.Hook("tcp_start", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> events.OpenConnectionReply(-1, "Connect call failed")
|
||||
<< commands.Hook("tcp_error", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
>> reply("Connect call failed")
|
||||
<< Hook("tcp_error", f)
|
||||
>> reply()
|
||||
<< CloseConnection(tctx.client)
|
||||
)
|
||||
|
||||
|
||||
def test_simple(tctx):
|
||||
"""open connection, receive data, send it to peer"""
|
||||
f = tutils.Placeholder()
|
||||
playbook = tutils.playbook(tcp.TCPLayer(tctx))
|
||||
f = Placeholder()
|
||||
|
||||
assert (
|
||||
playbook
|
||||
<< commands.Hook("tcp_start", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> events.OpenConnectionReply(-1, None)
|
||||
>> events.DataReceived(tctx.client, b"hello!")
|
||||
<< commands.Hook("tcp_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.server, b"hello!")
|
||||
>> events.DataReceived(tctx.server, b"hi")
|
||||
<< commands.Hook("tcp_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.client, b"hi")
|
||||
>> events.ConnectionClosed(tctx.server)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
<< commands.Hook("tcp_end", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.ConnectionClosed(tctx.client)
|
||||
<< None
|
||||
playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
>> reply(None)
|
||||
>> DataReceived(tctx.client, b"hello!")
|
||||
<< Hook("tcp_message", f)
|
||||
>> reply()
|
||||
<< SendData(tctx.server, b"hello!")
|
||||
>> DataReceived(tctx.server, b"hi")
|
||||
<< Hook("tcp_message", f)
|
||||
>> reply()
|
||||
<< SendData(tctx.client, b"hi")
|
||||
>> ConnectionClosed(tctx.server)
|
||||
<< CloseConnection(tctx.client)
|
||||
>> ConnectionClosed(tctx.client)
|
||||
<< CloseConnection(tctx.server)
|
||||
<< Hook("tcp_end", f)
|
||||
>> reply()
|
||||
>> ConnectionClosed(tctx.client)
|
||||
<< None
|
||||
)
|
||||
assert len(f().messages) == 2
|
||||
|
||||
|
||||
def test_simple_explicit(tctx):
|
||||
"""
|
||||
For comparison, test_simple without the playbook() sugar.
|
||||
This is not substantially more code, but the playbook syntax feels cleaner to me.
|
||||
"""
|
||||
layer = tcp.TCPLayer(tctx)
|
||||
tcp_start, = layer.handle_event(events.Start())
|
||||
flow = tcp_start.data
|
||||
assert tutils._eq(tcp_start, commands.Hook("tcp_start", flow))
|
||||
open_conn, = layer.handle_event(events.HookReply(tcp_start))
|
||||
assert tutils._eq(open_conn, commands.OpenConnection(tctx.server))
|
||||
assert list(layer.handle_event(events.OpenConnectionReply(open_conn, None))) == []
|
||||
tcp_msg, = layer.handle_event(events.DataReceived(tctx.client, b"hello!"))
|
||||
assert tutils._eq(tcp_msg, commands.Hook("tcp_message", flow))
|
||||
assert flow.messages[0].content == b"hello!"
|
||||
|
||||
send, = layer.handle_event(events.HookReply(tcp_msg))
|
||||
assert tutils._eq(send, commands.SendData(tctx.server, b"hello!"))
|
||||
close, tcp_end = layer.handle_event(events.ConnectionClosed(tctx.server))
|
||||
assert tutils._eq(close, commands.CloseConnection(tctx.client))
|
||||
assert tutils._eq(tcp_end, commands.Hook("tcp_end", flow))
|
||||
assert list(layer.handle_event(events.HookReply(tcp_end))) == []
|
||||
|
||||
|
||||
def test_receive_data_before_server_connected(tctx):
|
||||
"""
|
||||
assert that data received before a server connection is established
|
||||
will still be forwarded.
|
||||
"""
|
||||
f = tutils.Placeholder()
|
||||
f = Placeholder()
|
||||
assert (
|
||||
tutils.playbook(tcp.TCPLayer(tctx))
|
||||
<< commands.Hook("tcp_start", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> events.DataReceived(tctx.client, b"hello!")
|
||||
>> events.OpenConnectionReply(-2, None)
|
||||
<< commands.Hook("tcp_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.server, b"hello!")
|
||||
playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
>> DataReceived(tctx.client, b"hello!")
|
||||
>> reply(None, to=-2)
|
||||
<< Hook("tcp_message", f)
|
||||
>> reply()
|
||||
<< SendData(tctx.server, b"hello!")
|
||||
)
|
||||
assert f().messages
|
||||
|
||||
|
||||
def test_receive_data_after_server_disconnected(tctx):
|
||||
def test_receive_data_after_half_close(tctx):
|
||||
"""
|
||||
data received after a connection has been closed should just be discarded.
|
||||
data received after the other connection has been half-closed should still be forwarded.
|
||||
"""
|
||||
f = tutils.Placeholder()
|
||||
f = Placeholder()
|
||||
assert (
|
||||
tutils.playbook(tcp.TCPLayer(tctx))
|
||||
<< commands.Hook("tcp_start", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> events.OpenConnectionReply(-1, None)
|
||||
>> events.ConnectionClosed(tctx.server)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
<< commands.Hook("tcp_end", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, b"i'm late")
|
||||
<< None
|
||||
)
|
||||
# not included here as it has not been sent to the server.
|
||||
assert not f().messages
|
||||
playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
>> reply(None)
|
||||
>> ConnectionClosed(tctx.server)
|
||||
<< CloseConnection(tctx.client)
|
||||
>> DataReceived(tctx.client, b"i'm late")
|
||||
<< Hook("tcp_message", f)
|
||||
>> reply()
|
||||
<< SendData(tctx.server, b"i'm late")
|
||||
>> ConnectionClosed(tctx.client)
|
||||
<< CloseConnection(tctx.server)
|
||||
<< Hook("tcp_end", f)
|
||||
>> reply()
|
||||
<< None
|
||||
)
|
@ -22,7 +22,7 @@ class TCommand(commands.Command):
|
||||
|
||||
|
||||
class TCommandReply(events.CommandReply):
|
||||
pass
|
||||
command: TCommand
|
||||
|
||||
|
||||
class TLayer(Layer):
|
||||
@ -52,7 +52,7 @@ def test_simple(tplaybook):
|
||||
|
||||
|
||||
def test_mismatch(tplaybook):
|
||||
with pytest.raises(AssertionError, message="Playbook mismatch"):
|
||||
with pytest.raises(AssertionError, match="Playbook mismatch"):
|
||||
assert (
|
||||
tplaybook
|
||||
>> TEvent([])
|
||||
@ -135,7 +135,7 @@ def test_fork_placeholder(tplaybook):
|
||||
assert f2() == p2_flow
|
||||
|
||||
# re-using the old placeholder does not work.
|
||||
with pytest.raises(AssertionError, message="Playbook mismatch"):
|
||||
with pytest.raises(AssertionError, match="Playbook mismatch"):
|
||||
assert (
|
||||
p2
|
||||
>> TEvent([p2_flow])
|
||||
@ -146,7 +146,7 @@ def test_fork_placeholder(tplaybook):
|
||||
def test_unfinished(tplaybook):
|
||||
"""We show a warning when playbooks aren't asserted."""
|
||||
tplaybook >> TEvent()
|
||||
with pytest.raises(RuntimeError, message="Unfinished playbook"):
|
||||
with pytest.raises(RuntimeError, match="Unfinished playbook"):
|
||||
tplaybook.__del__()
|
||||
tplaybook._errored = True
|
||||
tplaybook.__del__()
|
||||
|
@ -1,12 +1,13 @@
|
||||
import collections.abc
|
||||
import copy
|
||||
import difflib
|
||||
import itertools
|
||||
import typing
|
||||
|
||||
import collections
|
||||
|
||||
from mitmproxy.proxy2 import commands, context
|
||||
from mitmproxy.proxy2 import events
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.proxy2.events import command_reply_subclasses
|
||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
||||
|
||||
TPlaybookEntry = typing.Union[commands.Command, events.Event]
|
||||
@ -14,8 +15,8 @@ TPlaybook = typing.List[TPlaybookEntry]
|
||||
|
||||
|
||||
def _eq(
|
||||
a: TPlaybookEntry,
|
||||
b: TPlaybookEntry
|
||||
a: TPlaybookEntry,
|
||||
b: TPlaybookEntry
|
||||
) -> bool:
|
||||
"""Compare two commands/events, and possibly update placeholders."""
|
||||
if type(a) != type(b):
|
||||
@ -43,24 +44,28 @@ def _eq(
|
||||
|
||||
|
||||
def eq(
|
||||
a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]],
|
||||
b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]]
|
||||
a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]],
|
||||
b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]]
|
||||
):
|
||||
"""
|
||||
Compare an indiviual event/command or a list of events/commands.
|
||||
"""
|
||||
if isinstance(a, collections.Iterable) and isinstance(b, collections.Iterable):
|
||||
if isinstance(a, collections.abc.Iterable) and isinstance(b, collections.abc.Iterable):
|
||||
return all(
|
||||
_eq(x, y) for x, y in itertools.zip_longest(a, b)
|
||||
)
|
||||
return _eq(a, b)
|
||||
|
||||
|
||||
T = typing.TypeVar('T', bound=Layer)
|
||||
def _str(x: typing.Union[events.Event, commands.Command]):
|
||||
arrow = ">>" if isinstance(x, events.Event) else "<<"
|
||||
x = str(x) \
|
||||
.replace('Placeholder:None', '<unset placeholder>') \
|
||||
.replace('Placeholder:', '')
|
||||
return f"{arrow} {x}"
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class playbook(typing.Generic[T]):
|
||||
class playbook:
|
||||
"""
|
||||
Assert that a layer emits the expected commands in reaction to a given sequence of events.
|
||||
For example, the following code asserts that the TCP layer emits an OpenConnection command
|
||||
@ -80,7 +85,7 @@ class playbook(typing.Generic[T]):
|
||||
x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1])))
|
||||
assert x2 == []
|
||||
"""
|
||||
layer: T
|
||||
layer: Layer
|
||||
"""The base layer"""
|
||||
expected: TPlaybook
|
||||
"""expected command/event sequence"""
|
||||
@ -92,10 +97,10 @@ class playbook(typing.Generic[T]):
|
||||
"""If True, log statements are ignored."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: T,
|
||||
expected: typing.Optional[TPlaybook] = None,
|
||||
ignore_log: bool = True
|
||||
self,
|
||||
layer: Layer,
|
||||
expected: typing.Optional[TPlaybook] = None,
|
||||
ignore_log: bool = True
|
||||
):
|
||||
if expected is None:
|
||||
expected = [
|
||||
@ -130,11 +135,12 @@ class playbook(typing.Generic[T]):
|
||||
if isinstance(x, commands.Command):
|
||||
pass
|
||||
else:
|
||||
if isinstance(x, events.CommandReply):
|
||||
if isinstance(x.command, int) and abs(x.command) < len(self.actual):
|
||||
x.command = self.actual[x.command]
|
||||
if hasattr(x, "_playbook_eval"):
|
||||
x._playbook_eval(self)
|
||||
if hasattr(x, "playbook_eval"):
|
||||
x = self.expected[i] = x.playbook_eval(self)
|
||||
if isinstance(x, events.OpenConnectionReply):
|
||||
x.command.connection.state = ConnectionState.OPEN
|
||||
elif isinstance(x, events.ConnectionClosed):
|
||||
x.connection.state &= ~ConnectionState.CAN_READ
|
||||
|
||||
self.actual.append(x)
|
||||
self.actual.extend(
|
||||
@ -148,14 +154,6 @@ class playbook(typing.Generic[T]):
|
||||
|
||||
if not eq(self.expected, self.actual):
|
||||
self._errored = True
|
||||
|
||||
def _str(x):
|
||||
arrow = ">>" if isinstance(x, events.Event) else "<<"
|
||||
x = str(x) \
|
||||
.replace('Placeholder:None', '<unset placeholder>') \
|
||||
.replace('Placeholder:', '')
|
||||
return f"{arrow} {x}"
|
||||
|
||||
diff = "\n".join(difflib.ndiff(
|
||||
[_str(x) for x in self.expected],
|
||||
[_str(x) for x in self.actual]
|
||||
@ -180,6 +178,48 @@ class playbook(typing.Generic[T]):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
class reply(events.Event):
|
||||
args: typing.Tuple[typing.Any, ...]
|
||||
to: typing.Union[commands.Command, int]
|
||||
side_effect: typing.Callable[[commands.Command], typing.Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
to: typing.Union[commands.Command, int] = -1,
|
||||
side_effect: typing.Callable[[commands.Command], typing.Any] = lambda cmd: None
|
||||
):
|
||||
"""Utility method to reply to the latest hook in playbooks."""
|
||||
self.args = args
|
||||
self.to = to
|
||||
self.side_effect = side_effect
|
||||
|
||||
def playbook_eval(self, playbook: playbook) -> events.CommandReply:
|
||||
if isinstance(self.to, int):
|
||||
expected = playbook.expected[:playbook.expected.index(self)]
|
||||
assert abs(self.to) < len(expected)
|
||||
to = expected[self.to]
|
||||
if not isinstance(to, commands.Command):
|
||||
raise AssertionError(f"There is no command at offset {self.to}: {to}")
|
||||
else:
|
||||
self.to = to
|
||||
for cmd in reversed(playbook.actual):
|
||||
if eq(self.to, cmd):
|
||||
self.to = cmd
|
||||
break
|
||||
else:
|
||||
actual_str = "\n".join(_str(x) for x in playbook.actual)
|
||||
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
|
||||
|
||||
self.side_effect(self.to)
|
||||
reply_cls = command_reply_subclasses[type(self.to)]
|
||||
try:
|
||||
inst = reply_cls(self.to, *self.args)
|
||||
except TypeError as e:
|
||||
raise ValueError(f"Cannot instantiate {reply_cls.__name__}: {e}")
|
||||
return inst
|
||||
|
||||
|
||||
class _Placeholder:
|
||||
"""
|
||||
Placeholder value in playbooks, so that objects (flows in particular) can be referenced before
|
||||
@ -209,6 +249,7 @@ class _Placeholder:
|
||||
return f"Placeholder:{str(self.obj)}"
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def Placeholder() -> typing.Any:
|
||||
return _Placeholder()
|
||||
|
||||
@ -222,7 +263,7 @@ class EchoLayer(Layer):
|
||||
|
||||
|
||||
def next_layer(
|
||||
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]]
|
||||
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]]
|
||||
) -> events.HookReply:
|
||||
"""
|
||||
Helper function to simplify the syntax for next_layer events from this:
|
||||
@ -238,7 +279,9 @@ def next_layer(
|
||||
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
>> tutils.next_layer(next_layer, tutils.EchoLayer)
|
||||
>> tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context)
|
||||
"""
|
||||
raise RuntimeError("Does tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) work?")
|
||||
if isinstance(layer, type):
|
||||
def make_layer(ctx: context.Context) -> Layer:
|
||||
return layer(ctx)
|
||||
|
Loading…
Reference in New Issue
Block a user