diff --git a/mitmproxy/addons/clientplayback_sansio.py b/mitmproxy/addons/clientplayback_sansio.py index a9b5c10f2..20055dec6 100644 --- a/mitmproxy/addons/clientplayback_sansio.py +++ b/mitmproxy/addons/clientplayback_sansio.py @@ -85,7 +85,7 @@ class ReplayHandler(server.ConnectionHandler): ctx.log(f"[replay] {message}", level) async def handle_hook(self, hook: commands.Hook) -> None: - data, = hook.as_tuple() + data, = hook.args() data.reply = AsyncReply(data) await ctx.master.addons.handle_lifecycle(hook.name, data) await data.reply.done.wait() diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index fad27db9f..ab8e23aaa 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -43,7 +43,7 @@ class ProxyConnectionHandler(server.StreamConnectionHandler): async def handle_hook(self, hook: commands.Hook) -> None: with self.timeout_watchdog.disarm(): # We currently only support single-argument hooks. - data, = hook.as_tuple() + data, = hook.args() data.reply = AsyncReply(data) await self.master.addons.handle_lifecycle(hook.name, data) await data.reply.done.wait() diff --git a/mitmproxy/proxy2/commands.py b/mitmproxy/proxy2/commands.py index 0fd68f9d7..3f9af7e3c 100644 --- a/mitmproxy/proxy2/commands.py +++ b/mitmproxy/proxy2/commands.py @@ -114,7 +114,7 @@ class Hook(Command): def __repr__(self): return f"Hook({self.name})" - def as_tuple(self) -> List[Any]: + def args(self) -> List[Any]: args = [] # noinspection PyDataclass for field in dataclasses.fields(self): diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index c12793924..f6f0d60ac 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -80,6 +80,14 @@ class Connection(serializable.Serializable, metaclass=ABCMeta): def tls_established(self) -> bool: return self.timestamp_tls_setup is not None + def __eq__(self, other): + if isinstance(other, Connection): + return self.id == other.id + return False + + def __hash__(self): + return hash(self.id) + def __repr__(self): attrs = repr({ k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k, lambda: v)() diff --git a/mitmproxy/proxy2/events.py b/mitmproxy/proxy2/events.py index 8d20b1839..c1875d346 100644 --- a/mitmproxy/proxy2/events.py +++ b/mitmproxy/proxy2/events.py @@ -36,13 +36,6 @@ class ConnectionEvent(Event): connection: Connection -class ConnectionClosed(ConnectionEvent): - """ - Remote has closed a connection. - """ - pass - - @dataclass class DataReceived(ConnectionEvent): """ @@ -55,6 +48,13 @@ class DataReceived(ConnectionEvent): return f"DataReceived({target}, {self.data})" +class ConnectionClosed(ConnectionEvent): + """ + Remote has closed a connection. + """ + pass + + class CommandReply(Event): """ Emitted when a command has been finished, e.g. @@ -64,41 +64,47 @@ class CommandReply(Event): reply: typing.Any def __new__(cls, *args, **kwargs): - assert is_dataclass(cls) if cls is CommandReply: raise TypeError("CommandReply may not be instantiated directly.") + assert is_dataclass(cls) return super().__new__(cls) def __init_subclass__(cls, **kwargs): command_cls = cls.__annotations__["command"] - if not issubclass(command_cls, commands.Command) and command_cls is not commands.Command: + valid_command_subclass = ( + issubclass(command_cls, commands.Command) and command_cls is not commands.Command + ) + if not valid_command_subclass: raise RuntimeError(f"{command_cls} needs a properly annotated command attribute.") if command_cls in command_reply_subclasses: other = command_reply_subclasses[command_cls] raise RuntimeError(f"Two conflicting subclasses for {command_cls}: {cls} and {other}") command_reply_subclasses[command_cls] = cls + def __repr__(self): + return f"Reply({repr(self.command)})" + command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {} -@dataclass +@dataclass(repr=False) class OpenConnectionReply(CommandReply): command: commands.OpenConnection reply: typing.Optional[str] """error message""" + def __repr__(self): + return super().__repr__() -@dataclass + +@dataclass(repr=False) class HookReply(CommandReply): command: commands.Hook reply: None = None - def __repr__(self): - return f"HookReply({repr(self.command)[5:-1]})" - -@dataclass +@dataclass(repr=False) class GetSocketReply(CommandReply): command: commands.GetSocket reply: socket.socket diff --git a/mitmproxy/proxy2/server.py b/mitmproxy/proxy2/server.py index c483d959e..e20a62fad 100644 --- a/mitmproxy/proxy2/server.py +++ b/mitmproxy/proxy2/server.py @@ -357,7 +357,7 @@ class SimpleConnectionHandler(StreamConnectionHandler): hook: commands.Hook ) -> None: if hook.name in self.hook_handlers: - self.hook_handlers[hook.name](*hook.as_tuple()) + self.hook_handlers[hook.name](*hook.args()) def log(self, message: str, level: str = "info"): if "Hook" not in message: diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index 8f8a94335..ae3b34a77 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -170,8 +170,9 @@ def tclient_conn() -> compat.Client: cipher_list=[], )) c.reply = controller.DummyReply() - c.rfile = io.BytesIO() - c.wfile = io.BytesIO() + if not compat.new_proxy_core: + c.rfile = io.BytesIO() + c.wfile = io.BytesIO() return c diff --git a/test/filename_matching.py b/test/filename_matching.py index c54cd1aa6..1257ce505 100755 --- a/test/filename_matching.py +++ b/test/filename_matching.py @@ -9,7 +9,7 @@ import sys def check_src_files_have_test(): missing_test_files = [] - excluded = ['mitmproxy/contrib/', 'mitmproxy/io/proto/', + excluded = ['mitmproxy/contrib/', 'mitmproxy/io/proto/', 'mitmproxy/proxy2/layers/http', 'mitmproxy/test/', 'mitmproxy/tools/', 'mitmproxy/platform/'] src_files = glob.glob('mitmproxy/**/*.py', recursive=True) + glob.glob('pathod/**/*.py', recursive=True) src_files = [f for f in src_files if os.path.basename(f) != '__init__.py'] diff --git a/test/mitmproxy/addons/test_tlsconfig.py b/test/mitmproxy/addons/test_tlsconfig.py index d01381b1e..fd0729c58 100644 --- a/test/mitmproxy/addons/test_tlsconfig.py +++ b/test/mitmproxy/addons/test_tlsconfig.py @@ -1,4 +1,5 @@ import ssl +import time from pathlib import Path from typing import Union @@ -166,6 +167,29 @@ class TestTlsConfig: tssl_server = test_tls.SSLTest(server_side=True) assert self.do_handshake(tssl_client, tssl_server) + def test_alpn_selection(self): + ta = tlsconfig.TlsConfig() + with taddons.context(ta) as tctx: + ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options) + ctx.server.address = ("example.mitmproxy.org", 443) + tls_start = tls.TlsStartData(ctx.server, context=ctx) + + def assert_alpn(http2, client_offers, expected): + tctx.configure(ta, http2=http2) + ctx.client.alpn_offers = client_offers + ctx.server.alpn_offers = None + ta.tls_start(tls_start) + assert ctx.server.alpn_offers == expected + + assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",)) + assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",)) + assert_alpn(True, [], tls.HTTP_ALPNS) + assert_alpn(False, [], tls.HTTP1_ALPNS) + ctx.client.timestamp_tls_setup = time.time() + # make sure that we don't upgrade h1 to h2, + # see comment in tlsconfig.py + assert_alpn(True, [], []) + @pytest.mark.parametrize( "client_certs", [ diff --git a/test/mitmproxy/proxy2/test_commands.py b/test/mitmproxy/proxy2/test_commands.py index e69de29bb..b70d3a569 100644 --- a/test/mitmproxy/proxy2/test_commands.py +++ b/test/mitmproxy/proxy2/test_commands.py @@ -0,0 +1,33 @@ +import pytest + +from mitmproxy.proxy2 import commands, context + + +@pytest.fixture +def tconn() -> context.Server: + return context.Server(None) + + +def test_dataclasses(tconn): + assert repr(commands.SendData(tconn, b"foo")) + assert repr(commands.OpenConnection(tconn)) + assert repr(commands.CloseConnection(tconn)) + assert repr(commands.GetSocket(tconn)) + assert repr(commands.Log("hello", "info")) + + +def test_hook(): + with pytest.raises(TypeError): + commands.Hook() + + class FooHook(commands.Hook): + data: bytes + + f = FooHook(b"foo") + assert repr(f) + assert f.args() == [b"foo"] + assert FooHook in commands.all_hooks.values() + + with pytest.raises(RuntimeError, match="Two conflicting hooks"): + class FooHook2(commands.Hook): + name = "foo" diff --git a/test/mitmproxy/proxy2/test_context.py b/test/mitmproxy/proxy2/test_context.py index e69de29bb..3952d967c 100644 --- a/test/mitmproxy/proxy2/test_context.py +++ b/test/mitmproxy/proxy2/test_context.py @@ -0,0 +1,82 @@ +from mitmproxy.proxy2 import context +from mitmproxy.test import tflow, taddons + + +class TestConnection: + def test_basic(self): + c = context.Client( + ("127.0.0.1", 52314), + ("127.0.0.1", 8080), + 1607780791 + ) + assert not c.tls_established + c.timestamp_tls_setup = 1607780792 + assert c.tls_established + assert c.connected + c.state = context.ConnectionState.CAN_WRITE + assert not c.connected + + def test_eq(self): + c = tflow.tclient_conn() + c2 = c.copy() + assert c == c + assert c != c2 + assert c != 42 + assert hash(c) != hash(c2) + + c2.id = c.id + assert c == c2 + + +class TestClient: + def test_basic(self): + c = context.Client( + ("127.0.0.1", 52314), + ("127.0.0.1", 8080), + 1607780791 + ) + assert repr(c) + + def test_state(self): + c = tflow.tclient_conn() + assert context.Client.from_state(c.get_state()).get_state() == c.get_state() + + c2 = tflow.tclient_conn() + assert c != c2 + + c2.timestamp_start = 42 + c.set_state(c2.get_state()) + assert c.timestamp_start == 42 + + c3 = c.copy() + assert c3.get_state() != c.get_state() + c.id = c3.id = "foo" + assert c3.get_state() == c.get_state() + + +class TestServer: + def test_basic(self): + s = context.Server(("address", 22)) + assert repr(s) + + def test_state(self): + c = tflow.tserver_conn() + c2 = c.copy() + assert c2.get_state() != c.get_state() + c.id = c2.id = "foo" + assert c2.get_state() == c.get_state() + + +def test_context(): + with taddons.context() as tctx: + c = context.Context( + tflow.tclient_conn(), + tctx.options + ) + assert repr(c) + c.layers.append(1) + c2 = c.fork() + c.layers.append(2) + c2.layers.append(3) + assert c.layers == [1, 2] + assert c2.layers == [1, 3] diff --git a/test/mitmproxy/proxy2/test_events.py b/test/mitmproxy/proxy2/test_events.py index e69de29bb..1640aa374 100644 --- a/test/mitmproxy/proxy2/test_events.py +++ b/test/mitmproxy/proxy2/test_events.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock + +import pytest + +from mitmproxy.proxy2 import events, context, commands + + +@pytest.fixture +def tconn() -> context.Server: + return context.Server(None) + + +def test_dataclasses(tconn): + assert repr(events.Start()) + assert repr(events.DataReceived(tconn, b"foo")) + assert repr(events.ConnectionClosed(tconn)) + + +def test_commandreply(): + with pytest.raises(TypeError): + events.CommandReply() + assert repr(events.HookReply(Mock(), None)) + + class FooCommand(commands.Command): + pass + + with pytest.raises(RuntimeError, match="properly annotated"): + class FooReply(events.CommandReply): + pass + + class FooReply1(events.CommandReply): + command: FooCommand + + with pytest.raises(RuntimeError, match="conflicting subclasses"): + class FooReply2(events.CommandReply): + command: FooCommand diff --git a/test/mitmproxy/proxy2/test_server_hooks.py b/test/mitmproxy/proxy2/test_server_hooks.py new file mode 100644 index 000000000..e1dfc7054 --- /dev/null +++ b/test/mitmproxy/proxy2/test_server_hooks.py @@ -0,0 +1,5 @@ +from mitmproxy.proxy2 import server_hooks + + +def test_noop(): + assert server_hooks diff --git a/test/mitmproxy/proxy2/test_tunnel.py b/test/mitmproxy/proxy2/test_tunnel.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/mitmproxy/proxy2/tutils.py b/test/mitmproxy/proxy2/tutils.py index 427c93e73..4e925fd06 100644 --- a/test/mitmproxy/proxy2/tutils.py +++ b/test/mitmproxy/proxy2/tutils.py @@ -306,7 +306,7 @@ class reply(events.Event): assert isinstance(self.to, commands.Command) if isinstance(self.to, commands.Hook): - self.side_effect(*self.to.as_tuple()) + self.side_effect(*self.to.args()) reply_cls = events.HookReply else: self.side_effect(self.to)