diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index bc8123280..e52fee8bf 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -1,7 +1,6 @@ from mitmproxy import ctx -from mitmproxy.net import server_spec -from mitmproxy.proxy.config import HostMatcher from mitmproxy.net.tls import is_tls_record_magic +from mitmproxy.proxy.config import HostMatcher from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy2 import layer, layers, context @@ -16,20 +15,6 @@ class NextLayer: if "tcp_hosts" in updated: self.check_tcp = HostMatcher(ctx.options.tcp_hosts) - def make_top_layer(self, context): - if ctx.options.mode == "regular": - return layers.modes.HttpProxy(context) - elif ctx.options.mode == "transparent": - raise NotImplementedError("Mode not implemented.") - elif ctx.options.mode == "socks5": - raise NotImplementedError("Mode not implemented.") - elif ctx.options.mode.startswith("reverse:"): - return layers.modes.ReverseProxy(context) - elif ctx.options.mode.startswith("upstream:"): - raise NotImplementedError("Mode not implemented.") - else: - raise NotImplementedError("Mode not implemented.") - def next_layer(self, nextlayer: layer.NextLayer): nextlayer.layer = self._next_layer(nextlayer, nextlayer.context) @@ -90,3 +75,17 @@ class NextLayer: # 8. Assume HTTP1 by default. return layers.HTTPLayer(context, HTTPMode.transparent) + + def make_top_layer(self, context): + if ctx.options.mode == "regular": + return layers.modes.HttpProxy(context) + elif ctx.options.mode == "transparent": + raise NotImplementedError("Mode not implemented.") + elif ctx.options.mode == "socks5": + raise NotImplementedError("Mode not implemented.") + elif ctx.options.mode.startswith("reverse:"): + return layers.modes.ReverseProxy(context) + elif ctx.options.mode.startswith("upstream:"): + raise NotImplementedError("Mode not implemented.") + else: + raise NotImplementedError("Mode not implemented.") diff --git a/mitmproxy/proxy2/commands.py b/mitmproxy/proxy2/commands.py index e7f3e68dc..52c2454db 100644 --- a/mitmproxy/proxy2/commands.py +++ b/mitmproxy/proxy2/commands.py @@ -8,7 +8,7 @@ The counterpart to commands are events. """ import typing -from mitmproxy.proxy2.context import Connection +from mitmproxy.proxy2.context import Connection, Server class Command: @@ -61,6 +61,7 @@ class OpenConnection(ConnectionCommand): """ Open a new connection """ + connection: Server blocking = True @@ -88,9 +89,7 @@ class Log(Command): message: str level: str - def __init__(self, message, level="info"): - assert isinstance(message, str) - assert isinstance(level, str) + def __init__(self, message: str, level: str="info"): self.message = message self.level = level diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py index 53aa22e56..d2999503a 100644 --- a/mitmproxy/proxy2/layers/tls.py +++ b/mitmproxy/proxy2/layers/tls.py @@ -109,8 +109,8 @@ class _TLSLayer(layer.Layer): yield commands.SendData(conn, data) def send( - self, - send_command: commands.SendData, + self, + send_command: commands.SendData, ) -> commands.TCommandGenerator: tls_conn = self.tls[send_command.connection] if send_command.connection.tls_established: @@ -288,13 +288,13 @@ class ClientTLSLayer(_TLSLayer): client.alpn_offers = client_hello.alpn_protocols client_tls_requires_server_connection = ( - self.context.server.tls and - self.context.options.upstream_cert and - ( - self.context.options.add_upstream_certs_to_client_chain or - client.alpn_offers or - not client.sni - ) + self.context.server.tls and + self.context.options.upstream_cert and + ( + self.context.options.add_upstream_certs_to_client_chain or + client.alpn_offers or + not client.sni + ) ) # What do we do with the client connection now? @@ -304,6 +304,9 @@ class ClientTLSLayer(_TLSLayer): else: yield from self.start_negotiate() self._handle_event = self.state_process + + # In any case, we now have enough information to start server TLS if needed. + yield from self.child_layer.handle_event(events.Start()) else: raise NotImplementedError(event) # TODO @@ -318,7 +321,7 @@ class ClientTLSLayer(_TLSLayer): def state_process(self, event: events.Event): if isinstance(event, events.DataReceived) and event.connection == self.context.client: - if not event.connection.tls_established: + if not self.context.client.tls_established: yield from self.negotiate(event) else: yield from self.relay(event) @@ -342,12 +345,7 @@ class ClientTLSLayer(_TLSLayer): if not (x.startswith(b"h2-") or x.startswith(b"spdy")) ] - yield from self.child_layer.handle_event(events.Start()) - def start_negotiate(self): - if not self.child_layer: - yield from self.child_layer.handle_event(events.Start()) - # FIXME: Do this properly client = self.context.client server = self.context.server diff --git a/test/mitmproxy/proxy2/conftest.py b/test/mitmproxy/proxy2/conftest.py index d428a7269..03ce5b7f7 100644 --- a/test/mitmproxy/proxy2/conftest.py +++ b/test/mitmproxy/proxy2/conftest.py @@ -8,6 +8,5 @@ from mitmproxy.proxy2 import context def tctx(): return context.Context( context.Client(("client", 1234)), - context.Server(("server", 42)), options.Options() ) diff --git a/test/mitmproxy/proxy2/layers/test_tls.py b/test/mitmproxy/proxy2/layers/test_tls.py index 7f199fd3f..d29adb307 100644 --- a/test/mitmproxy/proxy2/layers/test_tls.py +++ b/test/mitmproxy/proxy2/layers/test_tls.py @@ -81,37 +81,30 @@ class SSLTest: ) -def test_no_tls(tctx: context.Context): +def test_server_no_tls(tctx: context.Context): """Test TLS layer without TLS""" - layer = tls.TLSLayer(tctx) + layer = tls.ServerTLSLayer(tctx) playbook = tutils.playbook(layer) - next_layer = tutils.Placeholder() # Handshake assert ( playbook >> events.DataReceived(tctx.client, b"Hello World") - << commands.Hook("next_layer", next_layer) - ) - next_layer().layer = tutils.EchoLayer(next_layer().context) - assert ( - playbook - >> events.HookReply(-1) + << commands.Hook("next_layer", tutils.Placeholder()) + >> tutils.next_layer(tutils.EchoLayer) << commands.SendData(tctx.client, b"hello world") ) -def test_client_tls(tctx: context.Context): +def test_client_tls_only(tctx: context.Context): """Test TLS with client only""" - layer = tls.TLSLayer(tctx) + layer = tls.ClientTLSLayer(tctx) playbook = tutils.playbook(layer) - tctx.client.tls = True tssl = SSLTest() # Handshake assert playbook - assert layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING - assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS + assert layer._handle_event == layer.state_wait_for_clienthello def interact(): data = tutils.Placeholder() @@ -136,29 +129,25 @@ def test_client_tls(tctx: context.Context): assert interact() tssl.obj.do_handshake() - assert layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED - assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS + assert layer._handle_event == layer.state_process # Echo echo(playbook, tssl, tctx.client) - - -def echo(playbook, tssl, conn): - tconn = type(conn).__name__.lower() - tssl.obj.write(b"Hello World") - next_layer = tutils.Placeholder() assert ( playbook - >> events.DataReceived(conn, tssl.out.read()) - << commands.Log(f"PlainDataReceived({tconn}, b'Hello World')") - << commands.Hook("next_layer", next_layer) + >> events.DataReceived(tctx.server, b"Hello") + << commands.SendData(tctx.server, b"hello") ) - next_layer().layer = tutils.EchoLayer(next_layer().context) + + +def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None: + tssl.obj.write(b"Hello World") data = tutils.Placeholder() assert ( playbook - >> events.HookReply(-1) - << commands.Log(f"PlainSendData({tconn}, b'hello world')") + >> events.DataReceived(conn, tssl.out.read()) + << commands.Hook("next_layer", tutils.Placeholder()) + >> tutils.next_layer(tutils.EchoLayer) << commands.SendData(conn, data) ) tssl.inc.write(data()) @@ -166,20 +155,26 @@ def echo(playbook, tssl, conn): def test_server_tls_no_conn(tctx): - layer = tls.TLSLayer(tctx) + """ + The server TLS layer is initiated, but there is no active connection yet, so nothing + should be done. + """ + layer = tls.ServerTLSLayer(tctx) playbook = tutils.playbook(layer) tctx.server.tls = True # We did not have a server connection before, so let's do nothing. - assert playbook - assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS - assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS + assert ( + playbook + << None + ) def test_server_tls(tctx): - layer = tls.TLSLayer(tctx) + layer = tls.ServerTLSLayer(tctx) playbook = tutils.playbook(layer) tctx.server.connected = True + tctx.server.address = ("example.com", 443) tctx.server.tls = True tssl = SSLTest(server_side=True) @@ -190,8 +185,6 @@ def test_server_tls(tctx): playbook << commands.SendData(tctx.server, data) ) - assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS - assert layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING # receive ServerHello, finish client handshake tssl.inc.write(data()) @@ -213,24 +206,25 @@ def test_server_tls(tctx): << None ) - assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS - assert layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED + assert tctx.server.tls_established + assert tctx.server.sni == b"example.com" # Echo echo(playbook, tssl, tctx.server) def _test_tls_client_server(tctx, alpn): - layer = tls.TLSLayer(tctx) + layer = tls.ClientTLSLayer(tctx) playbook = tutils.playbook(layer) - tctx.client.tls = True tctx.server.tls = True + tctx.server.address = ("example.com", 443) tssl_client = SSLTest(alpn=alpn) # Handshake - assert playbook - assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO - assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO + assert ( + playbook + << None + ) with pytest.raises(ssl.SSLWantReadError): tssl_client.obj.do_handshake() @@ -241,9 +235,6 @@ def _test_tls_client_server(tctx, alpn): << None ) # Still waiting... - assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO - assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO - # Finish sending ClientHello playbook >> events.DataReceived(tctx.client, client_hello[42:]) return playbook, tssl_client @@ -263,8 +254,7 @@ def test_tls_client_server_no_server_conn(tctx): << commands.SendData(tctx.client, data) ) assert data() - assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING - assert playbook.layer.state[tctx.server] == tls.ConnectionState.NO_TLS + assert playbook.layer._handle_event == playbook.layer.state_process def test_tls_client_server_alpn(tctx): @@ -288,8 +278,8 @@ def test_tls_client_server_alpn(tctx): >> events.OpenConnectionReply(-1, None) << commands.SendData(tctx.server, data) ) - assert playbook.layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_SERVER_TLS - assert playbook.layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING + assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls + assert playbook.layer.child_layer.tls[tctx.server] # Establish TLS with the server... tssl_server.inc.write(data()) @@ -310,8 +300,8 @@ def test_tls_client_server_alpn(tctx): << commands.SendData(tctx.client, data) ) - assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING - assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED + assert playbook.layer._handle_event == playbook.layer.state_process + assert tctx.server.tls_established # Server TLS is established, we can now reply to the client handshake... tssl_client.inc.write(data()) @@ -327,8 +317,8 @@ def test_tls_client_server_alpn(tctx): tssl_client.obj.do_handshake() # Both handshakes completed! - assert playbook.layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED - assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED + assert tctx.client.tls_established + assert tctx.server.tls_established assert tssl_client.obj.selected_alpn_protocol() == "foo" assert tssl_server.obj.selected_alpn_protocol() == "foo" diff --git a/test/mitmproxy/proxy2/tutils.py b/test/mitmproxy/proxy2/tutils.py index 5ae0ea862..8edf3d30d 100644 --- a/test/mitmproxy/proxy2/tutils.py +++ b/test/mitmproxy/proxy2/tutils.py @@ -1,21 +1,21 @@ -import collections import copy import difflib import itertools import typing -from mitmproxy.proxy2 import commands +import collections + +from mitmproxy.proxy2 import commands, context from mitmproxy.proxy2 import events -from mitmproxy.proxy2 import layer -from mitmproxy.proxy2.layer import Layer +from mitmproxy.proxy2.layer import Layer, NextLayer TPlaybookEntry = typing.Union[commands.Command, events.Event] TPlaybook = typing.List[TPlaybookEntry] def _eq( - a: TPlaybookEntry, - b: TPlaybookEntry + a: TPlaybookEntry, + b: TPlaybookEntry ) -> bool: """Compare two commands/events, and possibly update placeholders.""" if type(a) != type(b): @@ -43,8 +43,8 @@ def _eq( def eq( - a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]], - b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]] + a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]], + b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]] ): """ Compare an indiviual event/command or a list of events/commands. @@ -76,7 +76,7 @@ class playbook: x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1]))) assert x2 == [] """ - layer: layer.Layer + layer: Layer """The base layer""" expected: TPlaybook """expected command/event sequence""" @@ -84,11 +84,14 @@ class playbook: """actual command/event sequence""" _errored: bool """used to check if playbook as been fully asserted""" + ignore_log: bool + """If True, log statements are ignored.""" def __init__( - self, - layer, - expected=None, + self, + layer: Layer, + expected: typing.Optional[TPlaybook]=None, + ignore_log: bool=True ): if expected is None: expected = [ @@ -99,6 +102,7 @@ class playbook: self.expected = expected self.actual = [] self._errored = False + self.ignore_log = ignore_log def __rshift__(self, e): """Add an event to send""" @@ -111,6 +115,7 @@ class playbook: if c is None: return self assert isinstance(c, commands.Command) + assert not (self.ignore_log and isinstance(c, commands.Log)) self.expected.append(c) return self @@ -124,19 +129,26 @@ class playbook: if isinstance(x, events.CommandReply): if isinstance(x.command, int) and abs(x.command) < len(self.actual): x.command = self.actual[x.command] + if hasattr(x, "_playbook_eval"): + x._playbook_eval(self) self.actual.append(x) self.actual.extend( self.layer.handle_event(x) ) + if self.ignore_log: + self.actual = [ + x for x in self.actual if not isinstance(x, commands.Log) + ] + if not eq(self.expected, self.actual): self._errored = True def _str(x): arrow = ">" if isinstance(x, events.Event) else "<" - x = str(x)\ - .replace('Placeholder:None', '')\ + x = str(x) \ + .replace('Placeholder:None', '') \ .replace('Placeholder:', '') return f"{arrow} {x}" @@ -189,9 +201,48 @@ class Placeholder: def __repr__(self): return f"Placeholder:{repr(self.obj)}" + def __str__(self): + return f"Placeholder:{str(self.obj)}" + class EchoLayer(Layer): """Echo layer that sends all data back to the client in lowercase.""" + def _handle_event(self, event: events.Event): if isinstance(event, events.DataReceived): yield commands.SendData(event.connection, event.data.lower()) + + +def next_layer( + layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]] +) -> events.HookReply: + """ + Helper function to simplify the syntax for next_layer events from this: + + << commands.Hook("next_layer", next_layer) + ) + next_layer().layer = tutils.EchoLayer(next_layer().context) + assert ( + playbook + >> events.HookReply(-1) + + to this: + + << commands.Hook("next_layer", next_layer) + >> tutils.next_layer(next_layer, tutils.EchoLayer) + """ + if isinstance(layer, type): + def make_layer(ctx: context.Context) -> Layer: + return layer(ctx) + else: + make_layer = layer + + def set_layer(playbook: playbook) -> None: + last_command = playbook.actual[-1] + assert isinstance(last_command, commands.Hook) + assert isinstance(last_command.data, NextLayer) + last_command.data.layer = make_layer(last_command.data.context) + + reply = events.HookReply(-1) + reply._playbook_eval = set_layer + return reply