[sans-io] test coverage++

This commit is contained in:
Maximilian Hils 2020-12-12 15:58:58 +01:00
parent 64d2ac8ef3
commit 4c75765387
15 changed files with 218 additions and 23 deletions

View File

@ -85,7 +85,7 @@ class ReplayHandler(server.ConnectionHandler):
ctx.log(f"[replay] {message}", level) ctx.log(f"[replay] {message}", level)
async def handle_hook(self, hook: commands.Hook) -> None: async def handle_hook(self, hook: commands.Hook) -> None:
data, = hook.as_tuple() data, = hook.args()
data.reply = AsyncReply(data) data.reply = AsyncReply(data)
await ctx.master.addons.handle_lifecycle(hook.name, data) await ctx.master.addons.handle_lifecycle(hook.name, data)
await data.reply.done.wait() await data.reply.done.wait()

View File

@ -43,7 +43,7 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
async def handle_hook(self, hook: commands.Hook) -> None: async def handle_hook(self, hook: commands.Hook) -> None:
with self.timeout_watchdog.disarm(): with self.timeout_watchdog.disarm():
# We currently only support single-argument hooks. # We currently only support single-argument hooks.
data, = hook.as_tuple() data, = hook.args()
data.reply = AsyncReply(data) data.reply = AsyncReply(data)
await self.master.addons.handle_lifecycle(hook.name, data) await self.master.addons.handle_lifecycle(hook.name, data)
await data.reply.done.wait() await data.reply.done.wait()

View File

@ -114,7 +114,7 @@ class Hook(Command):
def __repr__(self): def __repr__(self):
return f"Hook({self.name})" return f"Hook({self.name})"
def as_tuple(self) -> List[Any]: def args(self) -> List[Any]:
args = [] args = []
# noinspection PyDataclass # noinspection PyDataclass
for field in dataclasses.fields(self): for field in dataclasses.fields(self):

View File

@ -80,6 +80,14 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
def tls_established(self) -> bool: def tls_established(self) -> bool:
return self.timestamp_tls_setup is not None return self.timestamp_tls_setup is not None
def __eq__(self, other):
if isinstance(other, Connection):
return self.id == other.id
return False
def __hash__(self):
return hash(self.id)
def __repr__(self): def __repr__(self):
attrs = repr({ attrs = repr({
k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k, lambda: v)() k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k, lambda: v)()

View File

@ -36,13 +36,6 @@ class ConnectionEvent(Event):
connection: Connection connection: Connection
class ConnectionClosed(ConnectionEvent):
"""
Remote has closed a connection.
"""
pass
@dataclass @dataclass
class DataReceived(ConnectionEvent): class DataReceived(ConnectionEvent):
""" """
@ -55,6 +48,13 @@ class DataReceived(ConnectionEvent):
return f"DataReceived({target}, {self.data})" return f"DataReceived({target}, {self.data})"
class ConnectionClosed(ConnectionEvent):
"""
Remote has closed a connection.
"""
pass
class CommandReply(Event): class CommandReply(Event):
""" """
Emitted when a command has been finished, e.g. Emitted when a command has been finished, e.g.
@ -64,41 +64,47 @@ class CommandReply(Event):
reply: typing.Any reply: typing.Any
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
assert is_dataclass(cls)
if cls is CommandReply: if cls is CommandReply:
raise TypeError("CommandReply may not be instantiated directly.") raise TypeError("CommandReply may not be instantiated directly.")
assert is_dataclass(cls)
return super().__new__(cls) return super().__new__(cls)
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
command_cls = cls.__annotations__["command"] command_cls = cls.__annotations__["command"]
if not issubclass(command_cls, commands.Command) and command_cls is not commands.Command: valid_command_subclass = (
issubclass(command_cls, commands.Command) and command_cls is not commands.Command
)
if not valid_command_subclass:
raise RuntimeError(f"{command_cls} needs a properly annotated command attribute.") raise RuntimeError(f"{command_cls} needs a properly annotated command attribute.")
if command_cls in command_reply_subclasses: if command_cls in command_reply_subclasses:
other = command_reply_subclasses[command_cls] other = command_reply_subclasses[command_cls]
raise RuntimeError(f"Two conflicting subclasses for {command_cls}: {cls} and {other}") raise RuntimeError(f"Two conflicting subclasses for {command_cls}: {cls} and {other}")
command_reply_subclasses[command_cls] = cls command_reply_subclasses[command_cls] = cls
def __repr__(self):
return f"Reply({repr(self.command)})"
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {} command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
@dataclass @dataclass(repr=False)
class OpenConnectionReply(CommandReply): class OpenConnectionReply(CommandReply):
command: commands.OpenConnection command: commands.OpenConnection
reply: typing.Optional[str] reply: typing.Optional[str]
"""error message""" """error message"""
def __repr__(self):
return super().__repr__()
@dataclass
@dataclass(repr=False)
class HookReply(CommandReply): class HookReply(CommandReply):
command: commands.Hook command: commands.Hook
reply: None = None reply: None = None
def __repr__(self):
return f"HookReply({repr(self.command)[5:-1]})"
@dataclass(repr=False)
@dataclass
class GetSocketReply(CommandReply): class GetSocketReply(CommandReply):
command: commands.GetSocket command: commands.GetSocket
reply: socket.socket reply: socket.socket

View File

@ -357,7 +357,7 @@ class SimpleConnectionHandler(StreamConnectionHandler):
hook: commands.Hook hook: commands.Hook
) -> None: ) -> None:
if hook.name in self.hook_handlers: if hook.name in self.hook_handlers:
self.hook_handlers[hook.name](*hook.as_tuple()) self.hook_handlers[hook.name](*hook.args())
def log(self, message: str, level: str = "info"): def log(self, message: str, level: str = "info"):
if "Hook" not in message: if "Hook" not in message:

View File

@ -170,6 +170,7 @@ def tclient_conn() -> compat.Client:
cipher_list=[], cipher_list=[],
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
if not compat.new_proxy_core:
c.rfile = io.BytesIO() c.rfile = io.BytesIO()
c.wfile = io.BytesIO() c.wfile = io.BytesIO()
return c return c

View File

@ -9,7 +9,7 @@ import sys
def check_src_files_have_test(): def check_src_files_have_test():
missing_test_files = [] missing_test_files = []
excluded = ['mitmproxy/contrib/', 'mitmproxy/io/proto/', excluded = ['mitmproxy/contrib/', 'mitmproxy/io/proto/', 'mitmproxy/proxy2/layers/http',
'mitmproxy/test/', 'mitmproxy/tools/', 'mitmproxy/platform/'] 'mitmproxy/test/', 'mitmproxy/tools/', 'mitmproxy/platform/']
src_files = glob.glob('mitmproxy/**/*.py', recursive=True) + glob.glob('pathod/**/*.py', recursive=True) src_files = glob.glob('mitmproxy/**/*.py', recursive=True) + glob.glob('pathod/**/*.py', recursive=True)
src_files = [f for f in src_files if os.path.basename(f) != '__init__.py'] src_files = [f for f in src_files if os.path.basename(f) != '__init__.py']

View File

@ -1,4 +1,5 @@
import ssl import ssl
import time
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@ -166,6 +167,29 @@ class TestTlsConfig:
tssl_server = test_tls.SSLTest(server_side=True) tssl_server = test_tls.SSLTest(server_side=True)
assert self.do_handshake(tssl_client, tssl_server) assert self.do_handshake(tssl_client, tssl_server)
def test_alpn_selection(self):
ta = tlsconfig.TlsConfig()
with taddons.context(ta) as tctx:
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
ctx.server.address = ("example.mitmproxy.org", 443)
tls_start = tls.TlsStartData(ctx.server, context=ctx)
def assert_alpn(http2, client_offers, expected):
tctx.configure(ta, http2=http2)
ctx.client.alpn_offers = client_offers
ctx.server.alpn_offers = None
ta.tls_start(tls_start)
assert ctx.server.alpn_offers == expected
assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",))
assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",))
assert_alpn(True, [], tls.HTTP_ALPNS)
assert_alpn(False, [], tls.HTTP1_ALPNS)
ctx.client.timestamp_tls_setup = time.time()
# make sure that we don't upgrade h1 to h2,
# see comment in tlsconfig.py
assert_alpn(True, [], [])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"client_certs", "client_certs",
[ [

View File

@ -0,0 +1,33 @@
import pytest
from mitmproxy.proxy2 import commands, context
@pytest.fixture
def tconn() -> context.Server:
return context.Server(None)
def test_dataclasses(tconn):
assert repr(commands.SendData(tconn, b"foo"))
assert repr(commands.OpenConnection(tconn))
assert repr(commands.CloseConnection(tconn))
assert repr(commands.GetSocket(tconn))
assert repr(commands.Log("hello", "info"))
def test_hook():
with pytest.raises(TypeError):
commands.Hook()
class FooHook(commands.Hook):
data: bytes
f = FooHook(b"foo")
assert repr(f)
assert f.args() == [b"foo"]
assert FooHook in commands.all_hooks.values()
with pytest.raises(RuntimeError, match="Two conflicting hooks"):
class FooHook2(commands.Hook):
name = "foo"

View File

@ -0,0 +1,82 @@
from mitmproxy.proxy2 import context
from mitmproxy.test import tflow, taddons
class TestConnection:
def test_basic(self):
c = context.Client(
("127.0.0.1", 52314),
("127.0.0.1", 8080),
1607780791
)
assert not c.tls_established
c.timestamp_tls_setup = 1607780792
assert c.tls_established
assert c.connected
c.state = context.ConnectionState.CAN_WRITE
assert not c.connected
def test_eq(self):
c = tflow.tclient_conn()
c2 = c.copy()
assert c == c
assert c != c2
assert c != 42
assert hash(c) != hash(c2)
c2.id = c.id
assert c == c2
class TestClient:
def test_basic(self):
c = context.Client(
("127.0.0.1", 52314),
("127.0.0.1", 8080),
1607780791
)
assert repr(c)
def test_state(self):
c = tflow.tclient_conn()
assert context.Client.from_state(c.get_state()).get_state() == c.get_state()
c2 = tflow.tclient_conn()
assert c != c2
c2.timestamp_start = 42
c.set_state(c2.get_state())
assert c.timestamp_start == 42
c3 = c.copy()
assert c3.get_state() != c.get_state()
c.id = c3.id = "foo"
assert c3.get_state() == c.get_state()
class TestServer:
def test_basic(self):
s = context.Server(("address", 22))
assert repr(s)
def test_state(self):
c = tflow.tserver_conn()
c2 = c.copy()
assert c2.get_state() != c.get_state()
c.id = c2.id = "foo"
assert c2.get_state() == c.get_state()
def test_context():
with taddons.context() as tctx:
c = context.Context(
tflow.tclient_conn(),
tctx.options
)
assert repr(c)
c.layers.append(1)
c2 = c.fork()
c.layers.append(2)
c2.layers.append(3)
assert c.layers == [1, 2]
assert c2.layers == [1, 3]

View File

@ -0,0 +1,36 @@
from unittest.mock import Mock
import pytest
from mitmproxy.proxy2 import events, context, commands
@pytest.fixture
def tconn() -> context.Server:
return context.Server(None)
def test_dataclasses(tconn):
assert repr(events.Start())
assert repr(events.DataReceived(tconn, b"foo"))
assert repr(events.ConnectionClosed(tconn))
def test_commandreply():
with pytest.raises(TypeError):
events.CommandReply()
assert repr(events.HookReply(Mock(), None))
class FooCommand(commands.Command):
pass
with pytest.raises(RuntimeError, match="properly annotated"):
class FooReply(events.CommandReply):
pass
class FooReply1(events.CommandReply):
command: FooCommand
with pytest.raises(RuntimeError, match="conflicting subclasses"):
class FooReply2(events.CommandReply):
command: FooCommand

View File

@ -0,0 +1,5 @@
from mitmproxy.proxy2 import server_hooks
def test_noop():
assert server_hooks

View File

View File

@ -306,7 +306,7 @@ class reply(events.Event):
assert isinstance(self.to, commands.Command) assert isinstance(self.to, commands.Command)
if isinstance(self.to, commands.Hook): if isinstance(self.to, commands.Hook):
self.side_effect(*self.to.as_tuple()) self.side_effect(*self.to.args())
reply_cls = events.HookReply reply_cls = events.HookReply
else: else:
self.side_effect(self.to) self.side_effect(self.to)