mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] adjust tls tests
This commit is contained in:
parent
25373672c5
commit
8938aec2c0
@ -1,7 +1,6 @@
|
||||
from mitmproxy import ctx
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.proxy.config import HostMatcher
|
||||
from mitmproxy.net.tls import is_tls_record_magic
|
||||
from mitmproxy.proxy.config import HostMatcher
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import layer, layers, context
|
||||
|
||||
@ -16,20 +15,6 @@ class NextLayer:
|
||||
if "tcp_hosts" in updated:
|
||||
self.check_tcp = HostMatcher(ctx.options.tcp_hosts)
|
||||
|
||||
def make_top_layer(self, context):
|
||||
if ctx.options.mode == "regular":
|
||||
return layers.modes.HttpProxy(context)
|
||||
elif ctx.options.mode == "transparent":
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
elif ctx.options.mode == "socks5":
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
elif ctx.options.mode.startswith("reverse:"):
|
||||
return layers.modes.ReverseProxy(context)
|
||||
elif ctx.options.mode.startswith("upstream:"):
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
else:
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
|
||||
def next_layer(self, nextlayer: layer.NextLayer):
|
||||
nextlayer.layer = self._next_layer(nextlayer, nextlayer.context)
|
||||
|
||||
@ -90,3 +75,17 @@ class NextLayer:
|
||||
|
||||
# 8. Assume HTTP1 by default.
|
||||
return layers.HTTPLayer(context, HTTPMode.transparent)
|
||||
|
||||
def make_top_layer(self, context):
|
||||
if ctx.options.mode == "regular":
|
||||
return layers.modes.HttpProxy(context)
|
||||
elif ctx.options.mode == "transparent":
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
elif ctx.options.mode == "socks5":
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
elif ctx.options.mode.startswith("reverse:"):
|
||||
return layers.modes.ReverseProxy(context)
|
||||
elif ctx.options.mode.startswith("upstream:"):
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
else:
|
||||
raise NotImplementedError("Mode not implemented.")
|
||||
|
@ -8,7 +8,7 @@ The counterpart to commands are events.
|
||||
"""
|
||||
import typing
|
||||
|
||||
from mitmproxy.proxy2.context import Connection
|
||||
from mitmproxy.proxy2.context import Connection, Server
|
||||
|
||||
|
||||
class Command:
|
||||
@ -61,6 +61,7 @@ class OpenConnection(ConnectionCommand):
|
||||
"""
|
||||
Open a new connection
|
||||
"""
|
||||
connection: Server
|
||||
blocking = True
|
||||
|
||||
|
||||
@ -88,9 +89,7 @@ class Log(Command):
|
||||
message: str
|
||||
level: str
|
||||
|
||||
def __init__(self, message, level="info"):
|
||||
assert isinstance(message, str)
|
||||
assert isinstance(level, str)
|
||||
def __init__(self, message: str, level: str="info"):
|
||||
self.message = message
|
||||
self.level = level
|
||||
|
||||
|
@ -109,8 +109,8 @@ class _TLSLayer(layer.Layer):
|
||||
yield commands.SendData(conn, data)
|
||||
|
||||
def send(
|
||||
self,
|
||||
send_command: commands.SendData,
|
||||
self,
|
||||
send_command: commands.SendData,
|
||||
) -> commands.TCommandGenerator:
|
||||
tls_conn = self.tls[send_command.connection]
|
||||
if send_command.connection.tls_established:
|
||||
@ -288,13 +288,13 @@ class ClientTLSLayer(_TLSLayer):
|
||||
client.alpn_offers = client_hello.alpn_protocols
|
||||
|
||||
client_tls_requires_server_connection = (
|
||||
self.context.server.tls and
|
||||
self.context.options.upstream_cert and
|
||||
(
|
||||
self.context.options.add_upstream_certs_to_client_chain or
|
||||
client.alpn_offers or
|
||||
not client.sni
|
||||
)
|
||||
self.context.server.tls and
|
||||
self.context.options.upstream_cert and
|
||||
(
|
||||
self.context.options.add_upstream_certs_to_client_chain or
|
||||
client.alpn_offers or
|
||||
not client.sni
|
||||
)
|
||||
)
|
||||
|
||||
# What do we do with the client connection now?
|
||||
@ -304,6 +304,9 @@ class ClientTLSLayer(_TLSLayer):
|
||||
else:
|
||||
yield from self.start_negotiate()
|
||||
self._handle_event = self.state_process
|
||||
|
||||
# In any case, we now have enough information to start server TLS if needed.
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
else:
|
||||
raise NotImplementedError(event) # TODO
|
||||
|
||||
@ -318,7 +321,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
def state_process(self, event: events.Event):
|
||||
if isinstance(event, events.DataReceived) and event.connection == self.context.client:
|
||||
if not event.connection.tls_established:
|
||||
if not self.context.client.tls_established:
|
||||
yield from self.negotiate(event)
|
||||
else:
|
||||
yield from self.relay(event)
|
||||
@ -342,12 +345,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
|
||||
]
|
||||
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
|
||||
def start_negotiate(self):
|
||||
if not self.child_layer:
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
|
||||
# FIXME: Do this properly
|
||||
client = self.context.client
|
||||
server = self.context.server
|
||||
|
@ -8,6 +8,5 @@ from mitmproxy.proxy2 import context
|
||||
def tctx():
|
||||
return context.Context(
|
||||
context.Client(("client", 1234)),
|
||||
context.Server(("server", 42)),
|
||||
options.Options()
|
||||
)
|
||||
|
@ -81,37 +81,30 @@ class SSLTest:
|
||||
)
|
||||
|
||||
|
||||
def test_no_tls(tctx: context.Context):
|
||||
def test_server_no_tls(tctx: context.Context):
|
||||
"""Test TLS layer without TLS"""
|
||||
layer = tls.TLSLayer(tctx)
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
next_layer = tutils.Placeholder()
|
||||
|
||||
# Handshake
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"Hello World")
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
)
|
||||
next_layer().layer = tutils.EchoLayer(next_layer().context)
|
||||
assert (
|
||||
playbook
|
||||
>> events.HookReply(-1)
|
||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
||||
>> tutils.next_layer(tutils.EchoLayer)
|
||||
<< commands.SendData(tctx.client, b"hello world")
|
||||
)
|
||||
|
||||
|
||||
def test_client_tls(tctx: context.Context):
|
||||
def test_client_tls_only(tctx: context.Context):
|
||||
"""Test TLS with client only"""
|
||||
layer = tls.TLSLayer(tctx)
|
||||
layer = tls.ClientTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
tctx.client.tls = True
|
||||
tssl = SSLTest()
|
||||
|
||||
# Handshake
|
||||
assert playbook
|
||||
assert layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
|
||||
assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS
|
||||
assert layer._handle_event == layer.state_wait_for_clienthello
|
||||
|
||||
def interact():
|
||||
data = tutils.Placeholder()
|
||||
@ -136,29 +129,25 @@ def test_client_tls(tctx: context.Context):
|
||||
assert interact()
|
||||
tssl.obj.do_handshake()
|
||||
|
||||
assert layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED
|
||||
assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS
|
||||
assert layer._handle_event == layer.state_process
|
||||
|
||||
# Echo
|
||||
echo(playbook, tssl, tctx.client)
|
||||
|
||||
|
||||
def echo(playbook, tssl, conn):
|
||||
tconn = type(conn).__name__.lower()
|
||||
tssl.obj.write(b"Hello World")
|
||||
next_layer = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(conn, tssl.out.read())
|
||||
<< commands.Log(f"PlainDataReceived({tconn}, b'Hello World')")
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
>> events.DataReceived(tctx.server, b"Hello")
|
||||
<< commands.SendData(tctx.server, b"hello")
|
||||
)
|
||||
next_layer().layer = tutils.EchoLayer(next_layer().context)
|
||||
|
||||
|
||||
def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
|
||||
tssl.obj.write(b"Hello World")
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.HookReply(-1)
|
||||
<< commands.Log(f"PlainSendData({tconn}, b'hello world')")
|
||||
>> events.DataReceived(conn, tssl.out.read())
|
||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
||||
>> tutils.next_layer(tutils.EchoLayer)
|
||||
<< commands.SendData(conn, data)
|
||||
)
|
||||
tssl.inc.write(data())
|
||||
@ -166,20 +155,26 @@ def echo(playbook, tssl, conn):
|
||||
|
||||
|
||||
def test_server_tls_no_conn(tctx):
|
||||
layer = tls.TLSLayer(tctx)
|
||||
"""
|
||||
The server TLS layer is initiated, but there is no active connection yet, so nothing
|
||||
should be done.
|
||||
"""
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
tctx.server.tls = True
|
||||
|
||||
# We did not have a server connection before, so let's do nothing.
|
||||
assert playbook
|
||||
assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS
|
||||
assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS
|
||||
assert (
|
||||
playbook
|
||||
<< None
|
||||
)
|
||||
|
||||
|
||||
def test_server_tls(tctx):
|
||||
layer = tls.TLSLayer(tctx)
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
tctx.server.connected = True
|
||||
tctx.server.address = ("example.com", 443)
|
||||
tctx.server.tls = True
|
||||
|
||||
tssl = SSLTest(server_side=True)
|
||||
@ -190,8 +185,6 @@ def test_server_tls(tctx):
|
||||
playbook
|
||||
<< commands.SendData(tctx.server, data)
|
||||
)
|
||||
assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS
|
||||
assert layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING
|
||||
|
||||
# receive ServerHello, finish client handshake
|
||||
tssl.inc.write(data())
|
||||
@ -213,24 +206,25 @@ def test_server_tls(tctx):
|
||||
<< None
|
||||
)
|
||||
|
||||
assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS
|
||||
assert layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
|
||||
assert tctx.server.tls_established
|
||||
assert tctx.server.sni == b"example.com"
|
||||
|
||||
# Echo
|
||||
echo(playbook, tssl, tctx.server)
|
||||
|
||||
|
||||
def _test_tls_client_server(tctx, alpn):
|
||||
layer = tls.TLSLayer(tctx)
|
||||
layer = tls.ClientTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
tctx.client.tls = True
|
||||
tctx.server.tls = True
|
||||
tctx.server.address = ("example.com", 443)
|
||||
tssl_client = SSLTest(alpn=alpn)
|
||||
|
||||
# Handshake
|
||||
assert playbook
|
||||
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||
assert (
|
||||
playbook
|
||||
<< None
|
||||
)
|
||||
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl_client.obj.do_handshake()
|
||||
@ -241,9 +235,6 @@ def _test_tls_client_server(tctx, alpn):
|
||||
<< None
|
||||
)
|
||||
# Still waiting...
|
||||
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||
|
||||
# Finish sending ClientHello
|
||||
playbook >> events.DataReceived(tctx.client, client_hello[42:])
|
||||
return playbook, tssl_client
|
||||
@ -263,8 +254,7 @@ def test_tls_client_server_no_server_conn(tctx):
|
||||
<< commands.SendData(tctx.client, data)
|
||||
)
|
||||
assert data()
|
||||
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
|
||||
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NO_TLS
|
||||
assert playbook.layer._handle_event == playbook.layer.state_process
|
||||
|
||||
|
||||
def test_tls_client_server_alpn(tctx):
|
||||
@ -288,8 +278,8 @@ def test_tls_client_server_alpn(tctx):
|
||||
>> events.OpenConnectionReply(-1, None)
|
||||
<< commands.SendData(tctx.server, data)
|
||||
)
|
||||
assert playbook.layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_SERVER_TLS
|
||||
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING
|
||||
assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls
|
||||
assert playbook.layer.child_layer.tls[tctx.server]
|
||||
|
||||
# Establish TLS with the server...
|
||||
tssl_server.inc.write(data())
|
||||
@ -310,8 +300,8 @@ def test_tls_client_server_alpn(tctx):
|
||||
<< commands.SendData(tctx.client, data)
|
||||
)
|
||||
|
||||
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
|
||||
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
|
||||
assert playbook.layer._handle_event == playbook.layer.state_process
|
||||
assert tctx.server.tls_established
|
||||
|
||||
# Server TLS is established, we can now reply to the client handshake...
|
||||
tssl_client.inc.write(data())
|
||||
@ -327,8 +317,8 @@ def test_tls_client_server_alpn(tctx):
|
||||
tssl_client.obj.do_handshake()
|
||||
|
||||
# Both handshakes completed!
|
||||
assert playbook.layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED
|
||||
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
|
||||
assert tctx.client.tls_established
|
||||
assert tctx.server.tls_established
|
||||
|
||||
assert tssl_client.obj.selected_alpn_protocol() == "foo"
|
||||
assert tssl_server.obj.selected_alpn_protocol() == "foo"
|
||||
|
@ -1,21 +1,21 @@
|
||||
import collections
|
||||
import copy
|
||||
import difflib
|
||||
import itertools
|
||||
import typing
|
||||
|
||||
from mitmproxy.proxy2 import commands
|
||||
import collections
|
||||
|
||||
from mitmproxy.proxy2 import commands, context
|
||||
from mitmproxy.proxy2 import events
|
||||
from mitmproxy.proxy2 import layer
|
||||
from mitmproxy.proxy2.layer import Layer
|
||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
||||
|
||||
TPlaybookEntry = typing.Union[commands.Command, events.Event]
|
||||
TPlaybook = typing.List[TPlaybookEntry]
|
||||
|
||||
|
||||
def _eq(
|
||||
a: TPlaybookEntry,
|
||||
b: TPlaybookEntry
|
||||
a: TPlaybookEntry,
|
||||
b: TPlaybookEntry
|
||||
) -> bool:
|
||||
"""Compare two commands/events, and possibly update placeholders."""
|
||||
if type(a) != type(b):
|
||||
@ -43,8 +43,8 @@ def _eq(
|
||||
|
||||
|
||||
def eq(
|
||||
a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]],
|
||||
b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]]
|
||||
a: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]],
|
||||
b: typing.Union[TPlaybookEntry, typing.Iterable[TPlaybookEntry]]
|
||||
):
|
||||
"""
|
||||
Compare an indiviual event/command or a list of events/commands.
|
||||
@ -76,7 +76,7 @@ class playbook:
|
||||
x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1])))
|
||||
assert x2 == []
|
||||
"""
|
||||
layer: layer.Layer
|
||||
layer: Layer
|
||||
"""The base layer"""
|
||||
expected: TPlaybook
|
||||
"""expected command/event sequence"""
|
||||
@ -84,11 +84,14 @@ class playbook:
|
||||
"""actual command/event sequence"""
|
||||
_errored: bool
|
||||
"""used to check if playbook as been fully asserted"""
|
||||
ignore_log: bool
|
||||
"""If True, log statements are ignored."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer,
|
||||
expected=None,
|
||||
self,
|
||||
layer: Layer,
|
||||
expected: typing.Optional[TPlaybook]=None,
|
||||
ignore_log: bool=True
|
||||
):
|
||||
if expected is None:
|
||||
expected = [
|
||||
@ -99,6 +102,7 @@ class playbook:
|
||||
self.expected = expected
|
||||
self.actual = []
|
||||
self._errored = False
|
||||
self.ignore_log = ignore_log
|
||||
|
||||
def __rshift__(self, e):
|
||||
"""Add an event to send"""
|
||||
@ -111,6 +115,7 @@ class playbook:
|
||||
if c is None:
|
||||
return self
|
||||
assert isinstance(c, commands.Command)
|
||||
assert not (self.ignore_log and isinstance(c, commands.Log))
|
||||
self.expected.append(c)
|
||||
return self
|
||||
|
||||
@ -124,19 +129,26 @@ class playbook:
|
||||
if isinstance(x, events.CommandReply):
|
||||
if isinstance(x.command, int) and abs(x.command) < len(self.actual):
|
||||
x.command = self.actual[x.command]
|
||||
if hasattr(x, "_playbook_eval"):
|
||||
x._playbook_eval(self)
|
||||
|
||||
self.actual.append(x)
|
||||
self.actual.extend(
|
||||
self.layer.handle_event(x)
|
||||
)
|
||||
|
||||
if self.ignore_log:
|
||||
self.actual = [
|
||||
x for x in self.actual if not isinstance(x, commands.Log)
|
||||
]
|
||||
|
||||
if not eq(self.expected, self.actual):
|
||||
self._errored = True
|
||||
|
||||
def _str(x):
|
||||
arrow = ">" if isinstance(x, events.Event) else "<"
|
||||
x = str(x)\
|
||||
.replace('Placeholder:None', '<unset placeholder>')\
|
||||
x = str(x) \
|
||||
.replace('Placeholder:None', '<unset placeholder>') \
|
||||
.replace('Placeholder:', '')
|
||||
return f"{arrow} {x}"
|
||||
|
||||
@ -189,9 +201,48 @@ class Placeholder:
|
||||
def __repr__(self):
|
||||
return f"Placeholder:{repr(self.obj)}"
|
||||
|
||||
def __str__(self):
|
||||
return f"Placeholder:{str(self.obj)}"
|
||||
|
||||
|
||||
class EchoLayer(Layer):
|
||||
"""Echo layer that sends all data back to the client in lowercase."""
|
||||
|
||||
def _handle_event(self, event: events.Event):
|
||||
if isinstance(event, events.DataReceived):
|
||||
yield commands.SendData(event.connection, event.data.lower())
|
||||
|
||||
|
||||
def next_layer(
|
||||
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]]
|
||||
) -> events.HookReply:
|
||||
"""
|
||||
Helper function to simplify the syntax for next_layer events from this:
|
||||
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
)
|
||||
next_layer().layer = tutils.EchoLayer(next_layer().context)
|
||||
assert (
|
||||
playbook
|
||||
>> events.HookReply(-1)
|
||||
|
||||
to this:
|
||||
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
>> tutils.next_layer(next_layer, tutils.EchoLayer)
|
||||
"""
|
||||
if isinstance(layer, type):
|
||||
def make_layer(ctx: context.Context) -> Layer:
|
||||
return layer(ctx)
|
||||
else:
|
||||
make_layer = layer
|
||||
|
||||
def set_layer(playbook: playbook) -> None:
|
||||
last_command = playbook.actual[-1]
|
||||
assert isinstance(last_command, commands.Hook)
|
||||
assert isinstance(last_command.data, NextLayer)
|
||||
last_command.data.layer = make_layer(last_command.data.context)
|
||||
|
||||
reply = events.HookReply(-1)
|
||||
reply._playbook_eval = set_layer
|
||||
return reply
|
||||
|
Loading…
Reference in New Issue
Block a user