mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] test coverage++
This commit is contained in:
parent
64d2ac8ef3
commit
4c75765387
@ -85,7 +85,7 @@ class ReplayHandler(server.ConnectionHandler):
|
||||
ctx.log(f"[replay] {message}", level)
|
||||
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
data, = hook.as_tuple()
|
||||
data, = hook.args()
|
||||
data.reply = AsyncReply(data)
|
||||
await ctx.master.addons.handle_lifecycle(hook.name, data)
|
||||
await data.reply.done.wait()
|
||||
|
@ -43,7 +43,7 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
with self.timeout_watchdog.disarm():
|
||||
# We currently only support single-argument hooks.
|
||||
data, = hook.as_tuple()
|
||||
data, = hook.args()
|
||||
data.reply = AsyncReply(data)
|
||||
await self.master.addons.handle_lifecycle(hook.name, data)
|
||||
await data.reply.done.wait()
|
||||
|
@ -114,7 +114,7 @@ class Hook(Command):
|
||||
def __repr__(self):
|
||||
return f"Hook({self.name})"
|
||||
|
||||
def as_tuple(self) -> List[Any]:
|
||||
def args(self) -> List[Any]:
|
||||
args = []
|
||||
# noinspection PyDataclass
|
||||
for field in dataclasses.fields(self):
|
||||
|
@ -80,6 +80,14 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
|
||||
def tls_established(self) -> bool:
|
||||
return self.timestamp_tls_setup is not None
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Connection):
|
||||
return self.id == other.id
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def __repr__(self):
|
||||
attrs = repr({
|
||||
k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k, lambda: v)()
|
||||
|
@ -36,13 +36,6 @@ class ConnectionEvent(Event):
|
||||
connection: Connection
|
||||
|
||||
|
||||
class ConnectionClosed(ConnectionEvent):
|
||||
"""
|
||||
Remote has closed a connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataReceived(ConnectionEvent):
|
||||
"""
|
||||
@ -55,6 +48,13 @@ class DataReceived(ConnectionEvent):
|
||||
return f"DataReceived({target}, {self.data})"
|
||||
|
||||
|
||||
class ConnectionClosed(ConnectionEvent):
|
||||
"""
|
||||
Remote has closed a connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CommandReply(Event):
|
||||
"""
|
||||
Emitted when a command has been finished, e.g.
|
||||
@ -64,41 +64,47 @@ class CommandReply(Event):
|
||||
reply: typing.Any
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
assert is_dataclass(cls)
|
||||
if cls is CommandReply:
|
||||
raise TypeError("CommandReply may not be instantiated directly.")
|
||||
assert is_dataclass(cls)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
command_cls = cls.__annotations__["command"]
|
||||
if not issubclass(command_cls, commands.Command) and command_cls is not commands.Command:
|
||||
valid_command_subclass = (
|
||||
issubclass(command_cls, commands.Command) and command_cls is not commands.Command
|
||||
)
|
||||
if not valid_command_subclass:
|
||||
raise RuntimeError(f"{command_cls} needs a properly annotated command attribute.")
|
||||
if command_cls in command_reply_subclasses:
|
||||
other = command_reply_subclasses[command_cls]
|
||||
raise RuntimeError(f"Two conflicting subclasses for {command_cls}: {cls} and {other}")
|
||||
command_reply_subclasses[command_cls] = cls
|
||||
|
||||
def __repr__(self):
|
||||
return f"Reply({repr(self.command)})"
|
||||
|
||||
|
||||
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class OpenConnectionReply(CommandReply):
|
||||
command: commands.OpenConnection
|
||||
reply: typing.Optional[str]
|
||||
"""error message"""
|
||||
|
||||
def __repr__(self):
|
||||
return super().__repr__()
|
||||
|
||||
@dataclass
|
||||
|
||||
@dataclass(repr=False)
|
||||
class HookReply(CommandReply):
|
||||
command: commands.Hook
|
||||
reply: None = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"HookReply({repr(self.command)[5:-1]})"
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class GetSocketReply(CommandReply):
|
||||
command: commands.GetSocket
|
||||
reply: socket.socket
|
||||
|
@ -357,7 +357,7 @@ class SimpleConnectionHandler(StreamConnectionHandler):
|
||||
hook: commands.Hook
|
||||
) -> None:
|
||||
if hook.name in self.hook_handlers:
|
||||
self.hook_handlers[hook.name](*hook.as_tuple())
|
||||
self.hook_handlers[hook.name](*hook.args())
|
||||
|
||||
def log(self, message: str, level: str = "info"):
|
||||
if "Hook" not in message:
|
||||
|
@ -170,8 +170,9 @@ def tclient_conn() -> compat.Client:
|
||||
cipher_list=[],
|
||||
))
|
||||
c.reply = controller.DummyReply()
|
||||
c.rfile = io.BytesIO()
|
||||
c.wfile = io.BytesIO()
|
||||
if not compat.new_proxy_core:
|
||||
c.rfile = io.BytesIO()
|
||||
c.wfile = io.BytesIO()
|
||||
return c
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ import sys
|
||||
def check_src_files_have_test():
|
||||
missing_test_files = []
|
||||
|
||||
excluded = ['mitmproxy/contrib/', 'mitmproxy/io/proto/',
|
||||
excluded = ['mitmproxy/contrib/', 'mitmproxy/io/proto/', 'mitmproxy/proxy2/layers/http',
|
||||
'mitmproxy/test/', 'mitmproxy/tools/', 'mitmproxy/platform/']
|
||||
src_files = glob.glob('mitmproxy/**/*.py', recursive=True) + glob.glob('pathod/**/*.py', recursive=True)
|
||||
src_files = [f for f in src_files if os.path.basename(f) != '__init__.py']
|
||||
|
@ -1,4 +1,5 @@
|
||||
import ssl
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@ -166,6 +167,29 @@ class TestTlsConfig:
|
||||
tssl_server = test_tls.SSLTest(server_side=True)
|
||||
assert self.do_handshake(tssl_client, tssl_server)
|
||||
|
||||
def test_alpn_selection(self):
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
|
||||
ctx.server.address = ("example.mitmproxy.org", 443)
|
||||
tls_start = tls.TlsStartData(ctx.server, context=ctx)
|
||||
|
||||
def assert_alpn(http2, client_offers, expected):
|
||||
tctx.configure(ta, http2=http2)
|
||||
ctx.client.alpn_offers = client_offers
|
||||
ctx.server.alpn_offers = None
|
||||
ta.tls_start(tls_start)
|
||||
assert ctx.server.alpn_offers == expected
|
||||
|
||||
assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",))
|
||||
assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",))
|
||||
assert_alpn(True, [], tls.HTTP_ALPNS)
|
||||
assert_alpn(False, [], tls.HTTP1_ALPNS)
|
||||
ctx.client.timestamp_tls_setup = time.time()
|
||||
# make sure that we don't upgrade h1 to h2,
|
||||
# see comment in tlsconfig.py
|
||||
assert_alpn(True, [], [])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_certs",
|
||||
[
|
||||
|
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
|
||||
from mitmproxy.proxy2 import commands, context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tconn() -> context.Server:
|
||||
return context.Server(None)
|
||||
|
||||
|
||||
def test_dataclasses(tconn):
|
||||
assert repr(commands.SendData(tconn, b"foo"))
|
||||
assert repr(commands.OpenConnection(tconn))
|
||||
assert repr(commands.CloseConnection(tconn))
|
||||
assert repr(commands.GetSocket(tconn))
|
||||
assert repr(commands.Log("hello", "info"))
|
||||
|
||||
|
||||
def test_hook():
|
||||
with pytest.raises(TypeError):
|
||||
commands.Hook()
|
||||
|
||||
class FooHook(commands.Hook):
|
||||
data: bytes
|
||||
|
||||
f = FooHook(b"foo")
|
||||
assert repr(f)
|
||||
assert f.args() == [b"foo"]
|
||||
assert FooHook in commands.all_hooks.values()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Two conflicting hooks"):
|
||||
class FooHook2(commands.Hook):
|
||||
name = "foo"
|
@ -0,0 +1,82 @@
|
||||
from mitmproxy.proxy2 import context
|
||||
from mitmproxy.test import tflow, taddons
|
||||
|
||||
|
||||
class TestConnection:
|
||||
def test_basic(self):
|
||||
c = context.Client(
|
||||
("127.0.0.1", 52314),
|
||||
("127.0.0.1", 8080),
|
||||
1607780791
|
||||
)
|
||||
assert not c.tls_established
|
||||
c.timestamp_tls_setup = 1607780792
|
||||
assert c.tls_established
|
||||
assert c.connected
|
||||
c.state = context.ConnectionState.CAN_WRITE
|
||||
assert not c.connected
|
||||
|
||||
def test_eq(self):
|
||||
c = tflow.tclient_conn()
|
||||
c2 = c.copy()
|
||||
assert c == c
|
||||
assert c != c2
|
||||
assert c != 42
|
||||
assert hash(c) != hash(c2)
|
||||
|
||||
c2.id = c.id
|
||||
assert c == c2
|
||||
|
||||
|
||||
class TestClient:
|
||||
def test_basic(self):
|
||||
c = context.Client(
|
||||
("127.0.0.1", 52314),
|
||||
("127.0.0.1", 8080),
|
||||
1607780791
|
||||
)
|
||||
assert repr(c)
|
||||
|
||||
def test_state(self):
|
||||
c = tflow.tclient_conn()
|
||||
assert context.Client.from_state(c.get_state()).get_state() == c.get_state()
|
||||
|
||||
c2 = tflow.tclient_conn()
|
||||
assert c != c2
|
||||
|
||||
c2.timestamp_start = 42
|
||||
c.set_state(c2.get_state())
|
||||
assert c.timestamp_start == 42
|
||||
|
||||
c3 = c.copy()
|
||||
assert c3.get_state() != c.get_state()
|
||||
c.id = c3.id = "foo"
|
||||
assert c3.get_state() == c.get_state()
|
||||
|
||||
|
||||
class TestServer:
|
||||
def test_basic(self):
|
||||
s = context.Server(("address", 22))
|
||||
assert repr(s)
|
||||
|
||||
def test_state(self):
|
||||
c = tflow.tserver_conn()
|
||||
c2 = c.copy()
|
||||
assert c2.get_state() != c.get_state()
|
||||
c.id = c2.id = "foo"
|
||||
assert c2.get_state() == c.get_state()
|
||||
|
||||
|
||||
def test_context():
|
||||
with taddons.context() as tctx:
|
||||
c = context.Context(
|
||||
tflow.tclient_conn(),
|
||||
tctx.options
|
||||
)
|
||||
assert repr(c)
|
||||
c.layers.append(1)
|
||||
c2 = c.fork()
|
||||
c.layers.append(2)
|
||||
c2.layers.append(3)
|
||||
assert c.layers == [1, 2]
|
||||
assert c2.layers == [1, 3]
|
@ -0,0 +1,36 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from mitmproxy.proxy2 import events, context, commands
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tconn() -> context.Server:
|
||||
return context.Server(None)
|
||||
|
||||
|
||||
def test_dataclasses(tconn):
|
||||
assert repr(events.Start())
|
||||
assert repr(events.DataReceived(tconn, b"foo"))
|
||||
assert repr(events.ConnectionClosed(tconn))
|
||||
|
||||
|
||||
def test_commandreply():
|
||||
with pytest.raises(TypeError):
|
||||
events.CommandReply()
|
||||
assert repr(events.HookReply(Mock(), None))
|
||||
|
||||
class FooCommand(commands.Command):
|
||||
pass
|
||||
|
||||
with pytest.raises(RuntimeError, match="properly annotated"):
|
||||
class FooReply(events.CommandReply):
|
||||
pass
|
||||
|
||||
class FooReply1(events.CommandReply):
|
||||
command: FooCommand
|
||||
|
||||
with pytest.raises(RuntimeError, match="conflicting subclasses"):
|
||||
class FooReply2(events.CommandReply):
|
||||
command: FooCommand
|
5
test/mitmproxy/proxy2/test_server_hooks.py
Normal file
5
test/mitmproxy/proxy2/test_server_hooks.py
Normal file
@ -0,0 +1,5 @@
|
||||
from mitmproxy.proxy2 import server_hooks
|
||||
|
||||
|
||||
def test_noop():
|
||||
assert server_hooks
|
0
test/mitmproxy/proxy2/test_tunnel.py
Normal file
0
test/mitmproxy/proxy2/test_tunnel.py
Normal file
@ -306,7 +306,7 @@ class reply(events.Event):
|
||||
|
||||
assert isinstance(self.to, commands.Command)
|
||||
if isinstance(self.to, commands.Hook):
|
||||
self.side_effect(*self.to.as_tuple())
|
||||
self.side_effect(*self.to.args())
|
||||
reply_cls = events.HookReply
|
||||
else:
|
||||
self.side_effect(self.to)
|
||||
|
Loading…
Reference in New Issue
Block a user