From 41597272f9ba37103ef08e1810fb4bc10603bd87 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 24 Jul 2017 14:42:25 +0200 Subject: [PATCH] protocols: make playbooks immutable --- mitmproxy/proxy/protocol2/test/test_tutils.py | 59 +++++++++++++++++++ mitmproxy/proxy/protocol2/test/tutils.py | 43 ++++++++++---- 2 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 mitmproxy/proxy/protocol2/test/test_tutils.py diff --git a/mitmproxy/proxy/protocol2/test/test_tutils.py b/mitmproxy/proxy/protocol2/test/test_tutils.py new file mode 100644 index 000000000..8c6206e37 --- /dev/null +++ b/mitmproxy/proxy/protocol2/test/test_tutils.py @@ -0,0 +1,59 @@ +import typing + +from mitmproxy.proxy.protocol2 import events, commands +from mitmproxy.proxy.protocol2.layer import Layer +from mitmproxy.proxy.protocol2.test import tutils +from mitmproxy.proxy.protocol2.utils import expect + + +class TEvent(events.Event): + commands: typing.Iterable[typing.Any] + + def __init__(self, cmds=(None,)): + self.commands = cmds + + +class TCommand(commands.Command): + x: typing.Any + + def __init__(self, x=None): + self.x = x + + +class TLayer(Layer): + """ + Simple echo layer + """ + + @expect(TEvent) + def _handle_event(self, event: TEvent) -> commands.TCommandGenerator: + for x in event.commands: + yield TCommand(x) + + +def test_playbook_simple(tctx): + playbook = tutils.playbook(TLayer(tctx), []) + assert ( + playbook + >> TEvent() + << TCommand() + >> TEvent([]) + << None + ) + + +def test_playbook_partial_assert(tctx): + playbook = tutils.playbook(TLayer(tctx), []) + playbook = ( + playbook + >> TEvent() + << TCommand() + ) + assert playbook + playbook = ( + playbook + >> TEvent() + << TCommand() + ) + assert playbook + assert len(playbook.actual) == len(playbook.playbook) == 4 diff --git a/mitmproxy/proxy/protocol2/test/tutils.py b/mitmproxy/proxy/protocol2/test/tutils.py index 3773d56d0..af3177ffc 100644 --- a/mitmproxy/proxy/protocol2/test/tutils.py +++ b/mitmproxy/proxy/protocol2/test/tutils.py @@ -3,11 +3,14 @@ import itertools import re import typing +import copy + from mitmproxy.proxy.protocol2 import commands from mitmproxy.proxy.protocol2 import events from mitmproxy.proxy.protocol2 import layer -TPlaybook = typing.List[typing.Union[commands.Command, events.Event]] +TPlaybookEntry = typing.Union[commands.Command, events.Event] +TPlaybook = typing.List[TPlaybookEntry] def _eq( @@ -60,9 +63,13 @@ class playbook: assert x2 == [] """ layer: layer.Layer + """The base layer""" playbook: TPlaybook - _asserted: int + """expected command/event sequence""" actual: TPlaybook + """actual command/event sequence""" + _final: bool + """True if no << or >> operation has been called on this.""" def __init__( self, @@ -77,25 +84,38 @@ class playbook: self.layer = layer self.playbook = playbook self.actual = [] - self._asserted = 0 + self._final = True + + def _copy_with(self, entry: TPlaybookEntry): + self._final = False + p = playbook( + self.layer, + self.playbook + [entry] + ) + p.actual = self.actual.copy() + return p def __rshift__(self, e): """Add an event to send""" assert isinstance(e, events.Event) - self.playbook.append(e) - return self + return self._copy_with(e) def __lshift__(self, c): """Add an expected command""" if c is None: return self assert isinstance(c, commands.Command) - self.playbook.append(c) - return self + return self._copy_with(c) def __bool__(self): """Determine if playbook is correct.""" - for i, x in list(enumerate(self.playbook))[self._asserted:]: + + self.layer = copy.deepcopy(self.layer) + self.playbook = copy.deepcopy(self.playbook) + self.actual = copy.deepcopy(self.actual) + + already_asserted = len(self.actual) + for i, x in enumerate(self.playbook[already_asserted:], already_asserted): if isinstance(x, commands.Command): pass else: @@ -107,11 +127,10 @@ class playbook: self.actual.extend( self.layer.handle_event(x) ) - self._asserted = len(self.playbook) success = all( _eq(e, a) - for e, a in itertools.zip_longest(self.playbook, self.actual) + for e, a in itertools.zip_longest(self.playbook, self.actual) ) if not success: def _str(x): @@ -129,7 +148,7 @@ class playbook: return True def __del__(self): - if self._asserted < len(self.playbook): + if self._final and len(self.actual) < len(self.playbook): raise RuntimeError("Unfinished playbook!") @@ -144,4 +163,4 @@ class Placeholder: return self.obj def __repr__(self): - return repr(self.obj) + return f"P({repr(self.obj)})"