[sans-io] minor test improvements

This commit is contained in:
Maximilian Hils 2018-05-22 13:20:35 +02:00
parent 8938aec2c0
commit a860fe4a4b
2 changed files with 218 additions and 208 deletions

View File

@ -1,5 +1,6 @@
import os import os
import ssl import ssl
import typing
import pytest import pytest
@ -81,63 +82,34 @@ class SSLTest:
) )
def test_server_no_tls(tctx: context.Context): def _test_tls_client_server(
"""Test TLS layer without TLS""" tctx: context.Context,
layer = tls.ServerTLSLayer(tctx) alpn: typing.Optional[str]
playbook = tutils.playbook(layer) ) -> typing.Tuple[tutils.playbook[tls.ClientTLSLayer], SSLTest]:
# Handshake
assert (
playbook
>> events.DataReceived(tctx.client, b"Hello World")
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(tctx.client, b"hello world")
)
def test_client_tls_only(tctx: context.Context):
"""Test TLS with client only"""
layer = tls.ClientTLSLayer(tctx) layer = tls.ClientTLSLayer(tctx)
playbook = tutils.playbook(layer) playbook = tutils.playbook(layer)
tssl = SSLTest() tctx.server.tls = True
tctx.server.address = ("example.com", 443)
tssl_client = SSLTest(alpn=alpn)
# Handshake # Handshake
assert playbook
assert layer._handle_event == layer.state_wait_for_clienthello
def interact():
data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl.out.read()) << None
<< commands.SendData(tctx.client, data)
) )
tssl.inc.write(data())
try:
tssl.obj.do_handshake()
except ssl.SSLWantReadError:
return False
else:
return True
# receive ClientHello, send ServerHello
with pytest.raises(ssl.SSLWantReadError): with pytest.raises(ssl.SSLWantReadError):
tssl.obj.do_handshake() tssl_client.obj.do_handshake()
assert not interact() client_hello = tssl_client.out.read()
# Finish Handshake
assert interact()
tssl.obj.do_handshake()
assert layer._handle_event == layer.state_process
# Echo
echo(playbook, tssl, tctx.client)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.server, b"Hello") >> events.DataReceived(tctx.client, client_hello[:42])
<< commands.SendData(tctx.server, b"hello") << None
) )
# Still waiting...
# Finish sending ClientHello
playbook >> events.DataReceived(tctx.client, client_hello[42:])
return playbook, tssl_client
def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None: def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
@ -154,7 +126,22 @@ def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) ->
assert tssl.obj.read() == b"hello world" assert tssl.obj.read() == b"hello world"
def test_server_tls_no_conn(tctx): class TestServerTLS:
def test_no_tls(self, tctx: context.Context):
"""Test TLS layer without TLS"""
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
# Handshake
assert (
playbook
>> events.DataReceived(tctx.client, b"Hello World")
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(tctx.client, b"hello world")
)
def test_no_connection(self, tctx):
""" """
The server TLS layer is initiated, but there is no active connection yet, so nothing The server TLS layer is initiated, but there is no active connection yet, so nothing
should be done. should be done.
@ -169,8 +156,7 @@ def test_server_tls_no_conn(tctx):
<< None << None
) )
def test_simple(self, tctx):
def test_server_tls(tctx):
layer = tls.ServerTLSLayer(tctx) layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer) playbook = tutils.playbook(layer)
tctx.server.connected = True tctx.server.connected = True
@ -213,34 +199,51 @@ def test_server_tls(tctx):
echo(playbook, tssl, tctx.server) echo(playbook, tssl, tctx.server)
def _test_tls_client_server(tctx, alpn): class TestClientTLS:
def test_simple(self, tctx: context.Context):
"""Test TLS with client only"""
layer = tls.ClientTLSLayer(tctx) layer = tls.ClientTLSLayer(tctx)
playbook = tutils.playbook(layer) playbook = tutils.playbook(layer)
tctx.server.tls = True tssl = SSLTest()
tctx.server.address = ("example.com", 443)
tssl_client = SSLTest(alpn=alpn)
# Handshake # Handshake
assert playbook
assert layer._handle_event == layer.state_wait_for_clienthello
def interact():
data = tutils.Placeholder()
assert ( assert (
playbook playbook
<< None >> events.DataReceived(tctx.client, tssl.out.read())
<< commands.SendData(tctx.client, data)
) )
tssl.inc.write(data())
try:
tssl.obj.do_handshake()
except ssl.SSLWantReadError:
return False
else:
return True
# receive ClientHello, send ServerHello
with pytest.raises(ssl.SSLWantReadError): with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake() tssl.obj.do_handshake()
client_hello = tssl_client.out.read() assert not interact()
# Finish Handshake
assert interact()
tssl.obj.do_handshake()
assert layer._handle_event == layer.state_process
# Echo
echo(playbook, tssl, tctx.client)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, client_hello[:42]) >> events.DataReceived(tctx.server, b"Hello")
<< None << commands.SendData(tctx.server, b"hello")
) )
# Still waiting...
# Finish sending ClientHello
playbook >> events.DataReceived(tctx.client, client_hello[42:])
return playbook, tssl_client
def test_no_server_conn_required(self, tctx):
def test_tls_client_server_no_server_conn(tctx):
""" """
Here we test the scenario where a server connection is _not_ required Here we test the scenario where a server connection is _not_ required
to establish TLS with the client. After determining this when parsing the ClientHello, to establish TLS with the client. After determining this when parsing the ClientHello,
@ -256,8 +259,7 @@ def test_tls_client_server_no_server_conn(tctx):
assert data() assert data()
assert playbook.layer._handle_event == playbook.layer.state_process assert playbook.layer._handle_event == playbook.layer.state_process
def test_alpn(self, tctx):
def test_tls_client_server_alpn(tctx):
""" """
Here we test the scenario where a server connection is required (e.g. because of ALPN negotation) Here we test the scenario where a server connection is required (e.g. because of ALPN negotation)
to establish TLS with the client. to establish TLS with the client.

View File

@ -30,9 +30,9 @@ def _eq(
x, y = a[k], b[k] x, y = a[k], b[k]
# if there's a placeholder, make it x. # if there's a placeholder, make it x.
if isinstance(y, Placeholder): if isinstance(y, _Placeholder):
x, y = y, x x, y = y, x
if isinstance(x, Placeholder): if isinstance(x, _Placeholder):
if x.obj is None: if x.obj is None:
x.obj = y x.obj = y
x = x.obj x = x.obj
@ -56,7 +56,11 @@ def eq(
return _eq(a, b) return _eq(a, b)
class playbook: T = typing.TypeVar('T', bound=Layer)
# noinspection PyPep8Naming
class playbook(typing.Generic[T]):
""" """
Assert that a layer emits the expected commands in reaction to a given sequence of events. Assert that a layer emits the expected commands in reaction to a given sequence of events.
For example, the following code asserts that the TCP layer emits an OpenConnection command For example, the following code asserts that the TCP layer emits an OpenConnection command
@ -76,7 +80,7 @@ class playbook:
x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1]))) x2 = list(t.handle_event(events.OpenConnectionReply(x1[-1])))
assert x2 == [] assert x2 == []
""" """
layer: Layer layer: T
"""The base layer""" """The base layer"""
expected: TPlaybook expected: TPlaybook
"""expected command/event sequence""" """expected command/event sequence"""
@ -89,7 +93,7 @@ class playbook:
def __init__( def __init__(
self, self,
layer: Layer, layer: T,
expected: typing.Optional[TPlaybook] = None, expected: typing.Optional[TPlaybook] = None,
ignore_log: bool = True ignore_log: bool = True
): ):
@ -176,7 +180,7 @@ class playbook:
return copy.deepcopy(self) return copy.deepcopy(self)
class Placeholder: class _Placeholder:
""" """
Placeholder value in playbooks, so that objects (flows in particular) can be referenced before Placeholder value in playbooks, so that objects (flows in particular) can be referenced before
they are known. Example: they are known. Example:
@ -205,6 +209,10 @@ class Placeholder:
return f"Placeholder:{str(self.obj)}" return f"Placeholder:{str(self.obj)}"
def Placeholder() -> typing.Any:
return _Placeholder()
class EchoLayer(Layer): class EchoLayer(Layer):
"""Echo layer that sends all data back to the client in lowercase.""" """Echo layer that sends all data back to the client in lowercase."""