[sans-io] adjust tls tests

This commit is contained in:
Maximilian Hils 2018-05-16 19:14:36 +02:00
parent 25373672c5
commit 8938aec2c0
6 changed files with 139 additions and 103 deletions

View File

@ -1,7 +1,6 @@
from mitmproxy import ctx
from mitmproxy.net import server_spec
from mitmproxy.proxy.config import HostMatcher
from mitmproxy.net.tls import is_tls_record_magic
from mitmproxy.proxy.config import HostMatcher
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import layer, layers, context
@ -16,20 +15,6 @@ class NextLayer:
if "tcp_hosts" in updated:
self.check_tcp = HostMatcher(ctx.options.tcp_hosts)
def make_top_layer(self, context):
if ctx.options.mode == "regular":
return layers.modes.HttpProxy(context)
elif ctx.options.mode == "transparent":
raise NotImplementedError("Mode not implemented.")
elif ctx.options.mode == "socks5":
raise NotImplementedError("Mode not implemented.")
elif ctx.options.mode.startswith("reverse:"):
return layers.modes.ReverseProxy(context)
elif ctx.options.mode.startswith("upstream:"):
raise NotImplementedError("Mode not implemented.")
else:
raise NotImplementedError("Mode not implemented.")
def next_layer(self, nextlayer: layer.NextLayer):
nextlayer.layer = self._next_layer(nextlayer, nextlayer.context)
@ -90,3 +75,17 @@ class NextLayer:
# 8. Assume HTTP1 by default.
return layers.HTTPLayer(context, HTTPMode.transparent)
def make_top_layer(self, context):
if ctx.options.mode == "regular":
return layers.modes.HttpProxy(context)
elif ctx.options.mode == "transparent":
raise NotImplementedError("Mode not implemented.")
elif ctx.options.mode == "socks5":
raise NotImplementedError("Mode not implemented.")
elif ctx.options.mode.startswith("reverse:"):
return layers.modes.ReverseProxy(context)
elif ctx.options.mode.startswith("upstream:"):
raise NotImplementedError("Mode not implemented.")
else:
raise NotImplementedError("Mode not implemented.")

View File

@ -8,7 +8,7 @@ The counterpart to commands are events.
"""
import typing
from mitmproxy.proxy2.context import Connection
from mitmproxy.proxy2.context import Connection, Server
class Command:
@ -61,6 +61,7 @@ class OpenConnection(ConnectionCommand):
"""
Open a new connection
"""
connection: Server
blocking = True
@ -88,9 +89,7 @@ class Log(Command):
message: str
level: str
def __init__(self, message, level="info"):
assert isinstance(message, str)
assert isinstance(level, str)
def __init__(self, message: str, level: str="info"):
self.message = message
self.level = level

View File

@ -304,6 +304,9 @@ class ClientTLSLayer(_TLSLayer):
else:
yield from self.start_negotiate()
self._handle_event = self.state_process
# In any case, we now have enough information to start server TLS if needed.
yield from self.child_layer.handle_event(events.Start())
else:
raise NotImplementedError(event) # TODO
@ -318,7 +321,7 @@ class ClientTLSLayer(_TLSLayer):
def state_process(self, event: events.Event):
if isinstance(event, events.DataReceived) and event.connection == self.context.client:
if not event.connection.tls_established:
if not self.context.client.tls_established:
yield from self.negotiate(event)
else:
yield from self.relay(event)
@ -342,12 +345,7 @@ class ClientTLSLayer(_TLSLayer):
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
]
yield from self.child_layer.handle_event(events.Start())
def start_negotiate(self):
if not self.child_layer:
yield from self.child_layer.handle_event(events.Start())
# FIXME: Do this properly
client = self.context.client
server = self.context.server

View File

@ -8,6 +8,5 @@ from mitmproxy.proxy2 import context
def tctx():
return context.Context(
context.Client(("client", 1234)),
context.Server(("server", 42)),
options.Options()
)

View File

@ -81,37 +81,30 @@ class SSLTest:
)
def test_no_tls(tctx: context.Context):
def test_server_no_tls(tctx: context.Context):
"""Test TLS layer without TLS"""
layer = tls.TLSLayer(tctx)
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
next_layer = tutils.Placeholder()
# Handshake
assert (
playbook
>> events.DataReceived(tctx.client, b"Hello World")
<< commands.Hook("next_layer", next_layer)
)
next_layer().layer = tutils.EchoLayer(next_layer().context)
assert (
playbook
>> events.HookReply(-1)
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(tctx.client, b"hello world")
)
def test_client_tls(tctx: context.Context):
def test_client_tls_only(tctx: context.Context):
"""Test TLS with client only"""
layer = tls.TLSLayer(tctx)
layer = tls.ClientTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.client.tls = True
tssl = SSLTest()
# Handshake
assert playbook
assert layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS
assert layer._handle_event == layer.state_wait_for_clienthello
def interact():
data = tutils.Placeholder()
@ -136,29 +129,25 @@ def test_client_tls(tctx: context.Context):
assert interact()
tssl.obj.do_handshake()
assert layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED
assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS
assert layer._handle_event == layer.state_process
# Echo
echo(playbook, tssl, tctx.client)
def echo(playbook, tssl, conn):
tconn = type(conn).__name__.lower()
tssl.obj.write(b"Hello World")
next_layer = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(conn, tssl.out.read())
<< commands.Log(f"PlainDataReceived({tconn}, b'Hello World')")
<< commands.Hook("next_layer", next_layer)
>> events.DataReceived(tctx.server, b"Hello")
<< commands.SendData(tctx.server, b"hello")
)
next_layer().layer = tutils.EchoLayer(next_layer().context)
def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
tssl.obj.write(b"Hello World")
data = tutils.Placeholder()
assert (
playbook
>> events.HookReply(-1)
<< commands.Log(f"PlainSendData({tconn}, b'hello world')")
>> events.DataReceived(conn, tssl.out.read())
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(conn, data)
)
tssl.inc.write(data())
@ -166,20 +155,26 @@ def echo(playbook, tssl, conn):
def test_server_tls_no_conn(tctx):
layer = tls.TLSLayer(tctx)
"""
The server TLS layer is initiated, but there is no active connection yet, so nothing
should be done.
"""
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.server.tls = True
# We did not have a server connection before, so let's do nothing.
assert playbook
assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS
assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS
assert (
playbook
<< None
)
def test_server_tls(tctx):
layer = tls.TLSLayer(tctx)
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.server.connected = True
tctx.server.address = ("example.com", 443)
tctx.server.tls = True
tssl = SSLTest(server_side=True)
@ -190,8 +185,6 @@ def test_server_tls(tctx):
playbook
<< commands.SendData(tctx.server, data)
)
assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS
assert layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING
# receive ServerHello, finish client handshake
tssl.inc.write(data())
@ -213,24 +206,25 @@ def test_server_tls(tctx):
<< None
)
assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS
assert layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
assert tctx.server.tls_established
assert tctx.server.sni == b"example.com"
# Echo
echo(playbook, tssl, tctx.server)
def _test_tls_client_server(tctx, alpn):
layer = tls.TLSLayer(tctx)
layer = tls.ClientTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.client.tls = True
tctx.server.tls = True
tctx.server.address = ("example.com", 443)
tssl_client = SSLTest(alpn=alpn)
# Handshake
assert playbook
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
assert (
playbook
<< None
)
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
@ -241,9 +235,6 @@ def _test_tls_client_server(tctx, alpn):
<< None
)
# Still waiting...
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
# Finish sending ClientHello
playbook >> events.DataReceived(tctx.client, client_hello[42:])
return playbook, tssl_client
@ -263,8 +254,7 @@ def test_tls_client_server_no_server_conn(tctx):
<< commands.SendData(tctx.client, data)
)
assert data()
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NO_TLS
assert playbook.layer._handle_event == playbook.layer.state_process
def test_tls_client_server_alpn(tctx):
@ -288,8 +278,8 @@ def test_tls_client_server_alpn(tctx):
>> events.OpenConnectionReply(-1, None)
<< commands.SendData(tctx.server, data)
)
assert playbook.layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_SERVER_TLS
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING
assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls
assert playbook.layer.child_layer.tls[tctx.server]
# Establish TLS with the server...
tssl_server.inc.write(data())
@ -310,8 +300,8 @@ def test_tls_client_server_alpn(tctx):
<< commands.SendData(tctx.client, data)
)
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
assert playbook.layer._handle_event == playbook.layer.state_process
assert tctx.server.tls_established
# Server TLS is established, we can now reply to the client handshake...
tssl_client.inc.write(data())
@ -327,8 +317,8 @@ def test_tls_client_server_alpn(tctx):
tssl_client.obj.do_handshake()
# Both handshakes completed!
assert playbook.layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
assert tctx.client.tls_established
assert tctx.server.tls_established
assert tssl_client.obj.selected_alpn_protocol() == "foo"
assert tssl_server.obj.selected_alpn_protocol() == "foo"

View File

@ -1,13 +1,13 @@
import collections
import copy
import difflib
import itertools
import typing
from mitmproxy.proxy2 import commands
import collections
from mitmproxy.proxy2 import commands, context
from mitmproxy.proxy2 import events
from mitmproxy.proxy2 import layer
from mitmproxy.proxy2.layer import Layer
from mitmproxy.proxy2.layer import Layer, NextLayer
TPlaybookEntry = typing.Union[commands.Command, events.Event]
TPlaybook = typing.List[TPlaybookEntry]
@ -76,7 +76,7 @@ class playbook:
x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1])))
assert x2 == []
"""
layer: layer.Layer
layer: Layer
"""The base layer"""
expected: TPlaybook
"""expected command/event sequence"""
@ -84,11 +84,14 @@ class playbook:
"""actual command/event sequence"""
_errored: bool
"""used to check if playbook as been fully asserted"""
ignore_log: bool
"""If True, log statements are ignored."""
def __init__(
self,
layer,
expected=None,
layer: Layer,
expected: typing.Optional[TPlaybook]=None,
ignore_log: bool=True
):
if expected is None:
expected = [
@ -99,6 +102,7 @@ class playbook:
self.expected = expected
self.actual = []
self._errored = False
self.ignore_log = ignore_log
def __rshift__(self, e):
"""Add an event to send"""
@ -111,6 +115,7 @@ class playbook:
if c is None:
return self
assert isinstance(c, commands.Command)
assert not (self.ignore_log and isinstance(c, commands.Log))
self.expected.append(c)
return self
@ -124,12 +129,19 @@ class playbook:
if isinstance(x, events.CommandReply):
if isinstance(x.command, int) and abs(x.command) < len(self.actual):
x.command = self.actual[x.command]
if hasattr(x, "_playbook_eval"):
x._playbook_eval(self)
self.actual.append(x)
self.actual.extend(
self.layer.handle_event(x)
)
if self.ignore_log:
self.actual = [
x for x in self.actual if not isinstance(x, commands.Log)
]
if not eq(self.expected, self.actual):
self._errored = True
@ -189,9 +201,48 @@ class Placeholder:
def __repr__(self):
return f"Placeholder:{repr(self.obj)}"
def __str__(self):
return f"Placeholder:{str(self.obj)}"
class EchoLayer(Layer):
"""Echo layer that sends all data back to the client in lowercase."""
def _handle_event(self, event: events.Event):
if isinstance(event, events.DataReceived):
yield commands.SendData(event.connection, event.data.lower())
def next_layer(
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]]
) -> events.HookReply:
"""
Helper function to simplify the syntax for next_layer events from this:
<< commands.Hook("next_layer", next_layer)
)
next_layer().layer = tutils.EchoLayer(next_layer().context)
assert (
playbook
>> events.HookReply(-1)
to this:
<< commands.Hook("next_layer", next_layer)
>> tutils.next_layer(next_layer, tutils.EchoLayer)
"""
if isinstance(layer, type):
def make_layer(ctx: context.Context) -> Layer:
return layer(ctx)
else:
make_layer = layer
def set_layer(playbook: playbook) -> None:
last_command = playbook.actual[-1]
assert isinstance(last_command, commands.Hook)
assert isinstance(last_command.data, NextLayer)
last_command.data.layer = make_layer(last_command.data.context)
reply = events.HookReply(-1)
reply._playbook_eval = set_layer
return reply