mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] tls layer++
This commit is contained in:
parent
0c04638d8d
commit
1c80dfe17f
138
mitmproxy/addons/tlsconfig.py
Normal file
138
mitmproxy/addons/tlsconfig.py
Normal file
@ -0,0 +1,138 @@
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from OpenSSL import SSL, crypto
|
||||
|
||||
from mitmproxy import certs, ctx, exceptions
|
||||
from mitmproxy.net import tls as net_tls
|
||||
from mitmproxy.options import CONF_BASENAME
|
||||
from mitmproxy.proxy.protocol.tls import CIPHER_ID_NAME_MAP, DEFAULT_CLIENT_CIPHERS
|
||||
from mitmproxy.proxy2 import context
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
|
||||
|
||||
def alpn_select_callback(conn: SSL.Connection, options):
|
||||
server_alpn = conn.get_app_data()["server_alpn"]
|
||||
if server_alpn and server_alpn in options:
|
||||
return server_alpn
|
||||
for alpn in tls.HTTP_ALPNS:
|
||||
if alpn in options:
|
||||
return alpn
|
||||
else:
|
||||
# FIXME: pyOpenSSL requires that an ALPN is negotiated, we can't return SSL_TLSEXT_ERR_NOACK.
|
||||
return options[0]
|
||||
|
||||
|
||||
class TlsConfig:
|
||||
certstore: certs.CertStore
|
||||
|
||||
# TODO: We should re-use SSL.Context options here, if only for TLS session resumption.
|
||||
# This may require patches to pyOpenSSL, as some functionality is only exposed on contexts.
|
||||
|
||||
def get_cert(self, context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]:
|
||||
return self.certstore.get_cert(
|
||||
context.client.sni, [context.client.sni]
|
||||
)
|
||||
|
||||
def tls_start(self, tls_start: tls.TlsStart):
|
||||
if tls_start.conn == tls_start.context.client:
|
||||
self.create_client_proxy_ssl_conn(tls_start)
|
||||
else:
|
||||
self.create_proxy_server_ssl_conn(tls_start)
|
||||
|
||||
def create_client_proxy_ssl_conn(self, tls_start: tls.TlsStart) -> None:
|
||||
tls_method, tls_options = net_tls.VERSION_CHOICES[ctx.options.ssl_version_client]
|
||||
cert, key, chain_file = self.get_cert(tls_start.context)
|
||||
if ctx.options.add_upstream_certs_to_client_chain:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
extra_chain_certs = None
|
||||
ssl_ctx = net_tls.create_server_context(
|
||||
cert=cert,
|
||||
key=key,
|
||||
method=tls_method,
|
||||
options=tls_options,
|
||||
cipher_list=ctx.options.ciphers_client or DEFAULT_CLIENT_CIPHERS,
|
||||
dhparams=self.certstore.dhparams,
|
||||
chain_file=chain_file,
|
||||
alpn_select_callback=alpn_select_callback,
|
||||
extra_chain_certs=extra_chain_certs,
|
||||
)
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_app_data({
|
||||
"server_alpn": tls_start.context.server.alpn
|
||||
})
|
||||
|
||||
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStart) -> None:
|
||||
client = tls_start.context.client
|
||||
server: context.Server = tls_start.conn
|
||||
|
||||
if server.sni is True:
|
||||
server.sni = client.sni or server.address[0].encode()
|
||||
|
||||
if not server.alpn_offers:
|
||||
if client.alpn:
|
||||
server.alpn_offers = [client.alpn]
|
||||
elif client.alpn_offers:
|
||||
server.alpn_offers = client.alpn_offers
|
||||
|
||||
# We pass through the list of ciphers send by the client, because some HTTP/2 servers
|
||||
# will select a non-HTTP/2 compatible cipher from our default list and then hang up
|
||||
# because it's incompatible with h2.
|
||||
if not server.cipher_list:
|
||||
if ctx.options.ciphers_server:
|
||||
server.cipher_list = ctx.options.ciphers_server.split(":")
|
||||
elif client.cipher_list:
|
||||
server.cipher_list = [
|
||||
x for x in client.cipher_list
|
||||
if x in CIPHER_ID_NAME_MAP
|
||||
]
|
||||
|
||||
args = net_tls.client_arguments_from_options(ctx.options)
|
||||
|
||||
client_certs = args.pop("client_certs")
|
||||
client_cert: Optional[str] = None
|
||||
if client_certs:
|
||||
client_certs = os.path.expanduser(client_certs)
|
||||
if os.path.isfile(client_certs):
|
||||
client_cert = client_certs
|
||||
else:
|
||||
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
|
||||
path = os.path.join(client_certs, f"{server_name}.pem")
|
||||
if os.path.exists(path):
|
||||
client_cert = path
|
||||
|
||||
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
|
||||
ssl_ctx = net_tls.create_client_context(
|
||||
cert=client_cert,
|
||||
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.
|
||||
alpn_protos=server.alpn_offers,
|
||||
**args
|
||||
)
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
|
||||
def configure(self, updated):
|
||||
if not any(x in updated for x in ["confdir", "certs"]):
|
||||
return
|
||||
|
||||
certstore_path = os.path.expanduser(ctx.options.confdir)
|
||||
if not os.path.exists(os.path.dirname(certstore_path)):
|
||||
raise exceptions.OptionsError(
|
||||
f"Certificate Authority parent directory does not exist: {os.path.dirname(certstore_path)}")
|
||||
self.certstore = certs.CertStore.from_store(
|
||||
path=certstore_path,
|
||||
basename=CONF_BASENAME,
|
||||
key_size=ctx.options.key_size
|
||||
)
|
||||
for certspec in ctx.options.certs:
|
||||
parts = certspec.split("=", 1)
|
||||
if len(parts) == 1:
|
||||
parts = ["*", parts[0]]
|
||||
|
||||
cert = os.path.expanduser(parts[1])
|
||||
if not os.path.exists(cert):
|
||||
raise exceptions.OptionsError(f"Certificate file does not exist: {cert}")
|
||||
try:
|
||||
self.certstore.add_cert_file(parts[0], cert)
|
||||
except crypto.Error as e:
|
||||
raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e
|
@ -21,8 +21,12 @@ class Connection:
|
||||
tls_established: bool = False
|
||||
alpn: Optional[bytes] = None
|
||||
alpn_offers: Sequence[bytes] = ()
|
||||
cipher_list: Sequence[bytes] = ()
|
||||
tls_version: Optional[str] = None
|
||||
sni: Union[bytes, bool, None]
|
||||
|
||||
timestamp_tls_setup: Optional[float] = None
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
return self.state is ConnectionState.OPEN
|
||||
|
@ -67,7 +67,7 @@ class Layer:
|
||||
@abstractmethod
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
"""Handle a proxy server event"""
|
||||
yield from ()
|
||||
yield from () # pragma: no cover
|
||||
|
||||
def handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if self._paused:
|
||||
@ -95,10 +95,7 @@ class Layer:
|
||||
processing any other commands.
|
||||
"""
|
||||
try:
|
||||
if isinstance(send, Exception):
|
||||
command = command_generator.throw(type(send), send)
|
||||
else:
|
||||
command = command_generator.send(send)
|
||||
command = command_generator.send(send)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
|
@ -13,7 +13,7 @@ from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
||||
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply
|
||||
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
|
||||
@ -676,6 +676,9 @@ class HTTPLayer(Layer):
|
||||
|
||||
def make_http_connection(self, connection: Server) -> None:
|
||||
if connection.tls and not connection.tls_established:
|
||||
connection.alpn_offers = list(HTTP_ALPNS)
|
||||
if not self.context.options.http2:
|
||||
connection.alpn_offers.remove(b"h2")
|
||||
new_command = EstablishServerTLS(connection)
|
||||
new_command.blocking = object()
|
||||
yield new_command
|
||||
|
@ -1,12 +1,10 @@
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy.certs import CertStore
|
||||
from mitmproxy.net.tls import ClientHello
|
||||
from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS
|
||||
from mitmproxy.net import tls as net_tls
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2 import context
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
@ -69,7 +67,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
|
||||
return None
|
||||
|
||||
|
||||
def parse_client_hello(data: bytes) -> Optional[ClientHello]:
|
||||
def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
||||
"""
|
||||
Check if the supplied bytes contain a full ClientHello message,
|
||||
and if so, parse it.
|
||||
@ -84,10 +82,13 @@ def parse_client_hello(data: bytes) -> Optional[ClientHello]:
|
||||
# Check if ClientHello is complete
|
||||
client_hello = get_client_hello(data)
|
||||
if client_hello:
|
||||
return ClientHello(client_hello[4:])
|
||||
return net_tls.ClientHello(client_hello[4:])
|
||||
return None
|
||||
|
||||
|
||||
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
|
||||
|
||||
|
||||
class EstablishServerTLS(commands.ConnectionCommand):
|
||||
connection: context.Server
|
||||
blocking = True
|
||||
@ -99,9 +100,17 @@ class EstablishServerTLSReply(events.CommandReply):
|
||||
"""error message"""
|
||||
|
||||
|
||||
class TlsStart:
|
||||
def __init__(self, conn: context.Connection, context: context.Context) -> None:
|
||||
self.conn = conn
|
||||
self.context = context
|
||||
self.ssl_conn = None
|
||||
|
||||
|
||||
class _TLSLayer(layer.Layer):
|
||||
tls: Dict[context.Connection, SSL.Connection]
|
||||
child_layer: layer.Layer
|
||||
ssl_context: Optional[SSL.Context] = None
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context)
|
||||
@ -140,15 +149,18 @@ class _TLSLayer(layer.Layer):
|
||||
except SSL.WantReadError:
|
||||
yield from self.tls_interact(conn)
|
||||
return False, None
|
||||
except SSL.ZeroReturnError as e:
|
||||
except SSL.Error as e:
|
||||
return False, repr(e)
|
||||
else:
|
||||
conn.tls_established = True
|
||||
conn.sni = self.tls[conn].get_servername()
|
||||
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
||||
conn.cipher_list = self.tls[conn].get_cipher_list()
|
||||
conn.tls_version = self.tls[conn].get_protocol_version_name()
|
||||
conn.timestamp_tls_setup = time.time()
|
||||
yield commands.Log(f"TLS established: {conn}")
|
||||
yield from self.receive(conn, b"")
|
||||
# TODO: Set all other connection attributes here
|
||||
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
|
||||
return True, None
|
||||
|
||||
def receive(self, conn: context.Connection, data: bytes):
|
||||
@ -213,8 +225,8 @@ class ServerTLSLayer(_TLSLayer):
|
||||
self.command_to_reply_to = {}
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
|
||||
commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||
def negotiate(self, conn: context.Connection, data: bytes) \
|
||||
-> Generator[commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||
done, err = yield from super().negotiate(conn, data)
|
||||
if done or err:
|
||||
cmd = self.command_to_reply_to.pop(conn)
|
||||
@ -232,19 +244,11 @@ class ServerTLSLayer(_TLSLayer):
|
||||
def start_server_tls(self, conn: context.Server):
|
||||
assert conn not in self.tls
|
||||
assert conn.connected
|
||||
conn.tls = True
|
||||
|
||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if conn.alpn_offers:
|
||||
ssl_context.set_alpn_protos(conn.alpn_offers)
|
||||
self.tls[conn] = SSL.Connection(ssl_context)
|
||||
|
||||
if conn.sni:
|
||||
if conn.sni is True:
|
||||
if self.context.client.sni:
|
||||
conn.sni = self.context.client.sni
|
||||
else:
|
||||
conn.sni = conn.address[0].encode()
|
||||
self.tls[conn].set_tlsext_host_name(conn.sni)
|
||||
tls_start = TlsStart(conn, self.context)
|
||||
yield commands.Hook("tls_start", tls_start)
|
||||
self.tls[conn] = tls_start.ssl_conn
|
||||
self.tls[conn].set_connect_state()
|
||||
|
||||
yield from self.negotiate(conn, b"")
|
||||
@ -274,6 +278,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
super().__init__(context)
|
||||
self.recv_buffer = bytearray()
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = self.state_start
|
||||
|
||||
@expect(events.Start)
|
||||
def state_start(self, _) -> commands.TCommandGenerator:
|
||||
@ -281,9 +286,6 @@ class ClientTLSLayer(_TLSLayer):
|
||||
self._handle_event = self.state_wait_for_clienthello
|
||||
yield from ()
|
||||
|
||||
_handle_event = state_start
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def state_wait_for_clienthello(self, event: events.Event):
|
||||
client = self.context.client
|
||||
if isinstance(event, events.DataReceived) and event.connection == client:
|
||||
@ -296,8 +298,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
if client_hello:
|
||||
yield commands.Log(f"Client Hello: {client_hello}")
|
||||
|
||||
# TODO: Don't do double conversion
|
||||
client.sni = client_hello.sni.encode("idna")
|
||||
client.sni = client_hello.sni
|
||||
client.alpn_offers = client_hello.alpn_protocols
|
||||
|
||||
client_tls_requires_server_connection = (
|
||||
@ -322,8 +323,10 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
# In any case, we now have enough information to start server TLS if needed.
|
||||
yield from self.event_to_child(events.Start())
|
||||
elif isinstance(event, events.ConnectionClosed) and event.connection == client:
|
||||
self.recv_buffer.clear()
|
||||
else:
|
||||
raise NotImplementedError(event) # TODO
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
def start_server_tls(self):
|
||||
"""
|
||||
@ -339,11 +342,6 @@ class ClientTLSLayer(_TLSLayer):
|
||||
)
|
||||
return err
|
||||
|
||||
server.alpn_offers = [
|
||||
x for x in self.context.client.alpn_offers
|
||||
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
|
||||
]
|
||||
|
||||
err = yield EstablishServerTLS(server)
|
||||
if err:
|
||||
yield commands.Log(
|
||||
@ -352,36 +350,10 @@ class ClientTLSLayer(_TLSLayer):
|
||||
return err
|
||||
|
||||
def start_client_tls(self) -> commands.TCommandGenerator:
|
||||
# FIXME: Do this properly. Also adjust error message in negotiate()
|
||||
client = self.context.client
|
||||
server = self.context.server
|
||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
cert, privkey, cert_chain = CertStore.from_store(
|
||||
os.path.expanduser("~/.mitmproxy"), "mitmproxy",
|
||||
self.context.options.key_size
|
||||
).get_cert(client.sni, (client.sni,))
|
||||
context.use_privatekey(privkey)
|
||||
context.use_certificate(cert.x509)
|
||||
context.set_cipher_list(DEFAULT_CLIENT_CIPHERS)
|
||||
|
||||
def alpn_select_callback(conn_, options):
|
||||
if server.alpn in options:
|
||||
return server.alpn
|
||||
elif b"h2" in options:
|
||||
return b"h2"
|
||||
elif b"http/1.1" in options:
|
||||
return b"http/1.1"
|
||||
elif b"http/1.0" in options:
|
||||
return b"http/1.0"
|
||||
elif b"http/0.9" in options:
|
||||
return b"http/0.9"
|
||||
else:
|
||||
# FIXME: We MUST return something here. At this point we are at loss.
|
||||
return options[0]
|
||||
|
||||
context.set_alpn_select_callback(alpn_select_callback)
|
||||
|
||||
self.tls[client] = SSL.Connection(context)
|
||||
tls_start = TlsStart(client, self.context)
|
||||
yield commands.Hook("tls_start", tls_start)
|
||||
self.tls[client] = tls_start.ssl_conn
|
||||
self.tls[client].set_accept_state()
|
||||
|
||||
yield from self.negotiate(client, bytes(self.recv_buffer))
|
||||
@ -390,11 +362,16 @@ class ClientTLSLayer(_TLSLayer):
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||
done, err = yield from super().negotiate(conn, data)
|
||||
if err:
|
||||
if self.context.client.sni:
|
||||
# TODO: Also use other sources than SNI
|
||||
dest = " for " + self.context.client.sni.decode("idna")
|
||||
else:
|
||||
dest = ""
|
||||
yield commands.Log(
|
||||
f"Client TLS Handshake failed. "
|
||||
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
|
||||
f"The client may not trust the proxy's certificate{dest} ({err}).",
|
||||
level="warn"
|
||||
# TODO: Also use other sources than SNI
|
||||
|
||||
)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
return done
|
||||
|
@ -24,7 +24,7 @@ def test_open_connection(tctx):
|
||||
def test_open_connection_err(tctx):
|
||||
f = Placeholder()
|
||||
assert (
|
||||
playbook(TCPLayer(tctx), hooks=True)
|
||||
playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
@ -40,7 +40,7 @@ def test_simple(tctx):
|
||||
f = Placeholder()
|
||||
|
||||
assert (
|
||||
playbook(TCPLayer(tctx), hooks=True)
|
||||
playbook(TCPLayer(tctx))
|
||||
<< Hook("tcp_start", f)
|
||||
>> reply()
|
||||
<< OpenConnection(tctx.server)
|
||||
@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx):
|
||||
will still be forwarded.
|
||||
"""
|
||||
assert (
|
||||
playbook(TCPLayer(tctx))
|
||||
playbook(TCPLayer(tctx), hooks=False)
|
||||
<< OpenConnection(tctx.server)
|
||||
>> DataReceived(tctx.client, b"hello!")
|
||||
>> reply(None, to=-2)
|
||||
|
@ -3,11 +3,15 @@ import ssl
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy.proxy2 import context, events, commands
|
||||
from mitmproxy.proxy2 import commands, context, events
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.utils import data
|
||||
from test.mitmproxy.proxy2 import tutils
|
||||
|
||||
tlsdata = data.Data(__name__)
|
||||
|
||||
|
||||
def test_is_tls_handshake_record():
|
||||
assert tls.is_tls_handshake_record(bytes.fromhex("160300"))
|
||||
@ -33,11 +37,11 @@ def test_record_contents():
|
||||
|
||||
|
||||
def test_record_contents_err():
|
||||
with pytest.raises(ValueError, msg="Expected TLS record"):
|
||||
with pytest.raises(ValueError, match="Expected TLS record"):
|
||||
next(tls.handshake_record_contents(b"GET /error"))
|
||||
|
||||
empty_record = bytes.fromhex("1603010000")
|
||||
with pytest.raises(ValueError, msg="Record must not be empty"):
|
||||
with pytest.raises(ValueError, match="Record must not be empty"):
|
||||
next(tls.handshake_record_contents(empty_record))
|
||||
|
||||
|
||||
@ -53,8 +57,8 @@ def test_get_client_hello():
|
||||
assert tls.get_client_hello(single_record) == client_hello_no_extensions
|
||||
|
||||
split_over_two_records = (
|
||||
bytes.fromhex("1603010020") + client_hello_no_extensions[:32] +
|
||||
bytes.fromhex("1603010045") + client_hello_no_extensions[32:]
|
||||
bytes.fromhex("1603010020") + client_hello_no_extensions[:32] +
|
||||
bytes.fromhex("1603010045") + client_hello_no_extensions[32:]
|
||||
)
|
||||
assert tls.get_client_hello(split_over_two_records) == client_hello_no_extensions
|
||||
|
||||
@ -65,7 +69,8 @@ def test_get_client_hello():
|
||||
class SSLTest:
|
||||
"""Helper container for Python's builtin SSL object."""
|
||||
|
||||
def __init__(self, server_side=False, alpn=None):
|
||||
def __init__(self, server_side: bool = False, alpn: typing.List[bytes] = None,
|
||||
sni: typing.Optional[bytes] = b"example.com"):
|
||||
self.inc = ssl.MemoryBIO()
|
||||
self.out = ssl.MemoryBIO()
|
||||
self.ctx = ssl.SSLContext()
|
||||
@ -77,83 +82,78 @@ class SSLTest:
|
||||
self.obj = self.ctx.wrap_bio(
|
||||
self.inc,
|
||||
self.out,
|
||||
server_hostname=None if server_side else "example.com",
|
||||
server_hostname=None if server_side else sni,
|
||||
server_side=server_side,
|
||||
)
|
||||
|
||||
|
||||
def _test_tls_client_server(
|
||||
tctx: context.Context,
|
||||
alpn: typing.Optional[str]
|
||||
) -> typing.Tuple[tutils.playbook[tls.ClientTLSLayer], SSLTest]:
|
||||
layer = tls.ClientTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
tctx.server.tls = True
|
||||
tctx.server.address = ("example.com", 443)
|
||||
tssl_client = SSLTest(alpn=alpn)
|
||||
|
||||
# Handshake
|
||||
assert (
|
||||
playbook
|
||||
<< None
|
||||
)
|
||||
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl_client.obj.do_handshake()
|
||||
client_hello = tssl_client.out.read()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, client_hello[:42])
|
||||
<< None
|
||||
)
|
||||
# Still waiting...
|
||||
# Finish sending ClientHello
|
||||
playbook >> events.DataReceived(tctx.client, client_hello[42:])
|
||||
return playbook, tssl_client
|
||||
|
||||
|
||||
def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
|
||||
def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
|
||||
tssl.obj.write(b"Hello World")
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(conn, tssl.out.read())
|
||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
||||
>> tutils.next_layer(tutils.EchoLayer)
|
||||
<< commands.SendData(conn, data)
|
||||
playbook
|
||||
>> events.DataReceived(conn, tssl.out.read())
|
||||
<< commands.SendData(conn, data)
|
||||
)
|
||||
tssl.inc.write(data())
|
||||
assert tssl.obj.read() == b"hello world"
|
||||
|
||||
|
||||
class TlsEchoLayer(tutils.EchoLayer):
|
||||
err: typing.Optional[str] = None
|
||||
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
|
||||
# noinspection PyTypeChecker
|
||||
self.err = yield tls.EstablishServerTLS(self.context.server)
|
||||
else:
|
||||
yield from super()._handle_event(event)
|
||||
|
||||
|
||||
def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest):
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(conn, tssl.out.read())
|
||||
<< commands.SendData(conn, data)
|
||||
)
|
||||
tssl.inc.write(data())
|
||||
|
||||
|
||||
def reply_tls_start(*args, **kwargs) -> tutils.reply:
|
||||
"""
|
||||
Helper function to simplify the syntax for tls_start hooks.
|
||||
"""
|
||||
|
||||
def make_conn(hook: commands.Hook) -> None:
|
||||
tls_start = hook.data
|
||||
assert isinstance(tls_start, tls.TlsStart)
|
||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if tls_start.conn == tls_start.context.client:
|
||||
ssl_context.use_privatekey_file(
|
||||
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.key")
|
||||
)
|
||||
ssl_context.use_certificate_chain_file(
|
||||
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.crt")
|
||||
)
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||
|
||||
return tutils.reply(*args, side_effect=make_conn, **kwargs)
|
||||
|
||||
|
||||
class TestServerTLS:
|
||||
def test_no_tls(self, tctx: context.Context):
|
||||
"""Test TLS layer without TLS"""
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
layer.child_layer = TlsEchoLayer(tctx)
|
||||
|
||||
# Handshake
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"Hello World")
|
||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
||||
>> tutils.next_layer(tutils.EchoLayer)
|
||||
<< commands.SendData(tctx.client, b"hello world")
|
||||
)
|
||||
|
||||
def test_no_connection(self, tctx):
|
||||
"""
|
||||
The server TLS layer is initiated, but there is no active connection yet, so nothing
|
||||
should be done.
|
||||
"""
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
tctx.server.tls = True
|
||||
|
||||
# We did not have a server connection before, so let's do nothing.
|
||||
assert (
|
||||
playbook
|
||||
<< None
|
||||
tutils.playbook(layer)
|
||||
>> events.DataReceived(tctx.client, b"Hello World")
|
||||
<< commands.SendData(tctx.client, b"hello world")
|
||||
>> events.DataReceived(tctx.server, b"Foo")
|
||||
<< commands.SendData(tctx.server, b"foo")
|
||||
)
|
||||
|
||||
def test_simple(self, tctx):
|
||||
@ -161,166 +161,171 @@ class TestServerTLS:
|
||||
playbook = tutils.playbook(layer)
|
||||
tctx.server.connected = True
|
||||
tctx.server.address = ("example.com", 443)
|
||||
tctx.server.tls = True
|
||||
|
||||
tssl = SSLTest(server_side=True)
|
||||
|
||||
# send ClientHello
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
<< commands.SendData(tctx.server, data)
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"establish-server-tls")
|
||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
||||
>> tutils.next_layer(TlsEchoLayer)
|
||||
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.server, data)
|
||||
)
|
||||
|
||||
# receive ServerHello, finish client handshake
|
||||
tssl.inc.write(data())
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl.obj.do_handshake()
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||
<< commands.SendData(tctx.server, data)
|
||||
)
|
||||
tssl.inc.write(data())
|
||||
interact(playbook, tctx.server, tssl)
|
||||
|
||||
# finish server handshake
|
||||
tssl.obj.do_handshake()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||
<< None
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||
<< None
|
||||
)
|
||||
|
||||
assert tctx.server.tls_established
|
||||
assert tctx.server.sni == b"example.com"
|
||||
|
||||
# Echo
|
||||
echo(playbook, tssl, tctx.server)
|
||||
_test_echo(playbook, tssl, tctx.server)
|
||||
|
||||
|
||||
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]:
|
||||
# This is a bit contrived as the client layer expects a server layer as parent.
|
||||
# We also set child layers manually to avoid NextLayer noise.
|
||||
server_layer = tls.ServerTLSLayer(tctx)
|
||||
client_layer = tls.ClientTLSLayer(tctx)
|
||||
server_layer.child_layer = client_layer
|
||||
client_layer.child_layer = TlsEchoLayer(tctx)
|
||||
playbook = tutils.playbook(server_layer)
|
||||
return playbook, client_layer
|
||||
|
||||
|
||||
def _test_tls_client_server(
|
||||
tctx: context.Context,
|
||||
sni: typing.Optional[bytes]
|
||||
) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]:
|
||||
playbook, client_layer = _make_client_tls_layer(tctx)
|
||||
tctx.server.tls = True
|
||||
tctx.server.address = ("example.com", 443)
|
||||
tssl_client = SSLTest(sni=sni)
|
||||
|
||||
# Send ClientHello
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl_client.obj.do_handshake()
|
||||
|
||||
return playbook, client_layer, tssl_client
|
||||
|
||||
|
||||
class TestClientTLS:
|
||||
def test_simple(self, tctx: context.Context):
|
||||
def test_client_only(self, tctx: context.Context):
|
||||
"""Test TLS with client only"""
|
||||
layer = tls.ClientTLSLayer(tctx)
|
||||
playbook = tutils.playbook(layer)
|
||||
playbook, client_layer = _make_client_tls_layer(tctx)
|
||||
tssl = SSLTest()
|
||||
assert not tctx.client.tls_established
|
||||
|
||||
# Handshake
|
||||
assert playbook
|
||||
assert layer._handle_event == layer.state_wait_for_clienthello
|
||||
|
||||
def interact():
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl.out.read())
|
||||
<< commands.SendData(tctx.client, data)
|
||||
)
|
||||
tssl.inc.write(data())
|
||||
try:
|
||||
tssl.obj.do_handshake()
|
||||
except ssl.SSLWantReadError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
# receive ClientHello, send ServerHello
|
||||
# Start Handshake, send ClientHello and ServerHello
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl.obj.do_handshake()
|
||||
assert not interact()
|
||||
# Finish Handshake
|
||||
assert interact()
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl.out.read())
|
||||
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.client, data)
|
||||
)
|
||||
tssl.inc.write(data())
|
||||
tssl.obj.do_handshake()
|
||||
# Finish Handshake
|
||||
interact(playbook, tctx.client, tssl)
|
||||
|
||||
assert layer._handle_event == layer.state_process
|
||||
assert tssl.obj.getpeercert(True)
|
||||
assert tctx.client.tls_established
|
||||
|
||||
# Echo
|
||||
echo(playbook, tssl, tctx.client)
|
||||
_test_echo(playbook, tssl, tctx.client)
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, b"Hello")
|
||||
<< commands.SendData(tctx.server, b"hello")
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, b"Plaintext")
|
||||
<< commands.SendData(tctx.server, b"plaintext")
|
||||
)
|
||||
|
||||
def test_no_server_conn_required(self, tctx):
|
||||
def test_server_not_required(self, tctx):
|
||||
"""
|
||||
Here we test the scenario where a server connection is _not_ required
|
||||
to establish TLS with the client. After determining this when parsing the ClientHello,
|
||||
we only establish a connection with the client. The server connection may ultimately
|
||||
be established when OpenConnection is called.
|
||||
"""
|
||||
playbook, _ = _test_tls_client_server(tctx, None)
|
||||
playbook, client_layer, tssl = _test_tls_client_server(tctx, sni=b"example.com")
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
<< commands.SendData(tctx.client, data)
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl.out.read())
|
||||
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.client, data)
|
||||
)
|
||||
assert data()
|
||||
assert playbook.layer._handle_event == playbook.layer.state_process
|
||||
tssl.inc.write(data())
|
||||
tssl.obj.do_handshake()
|
||||
interact(playbook, tctx.client, tssl)
|
||||
assert tctx.client.tls_established
|
||||
|
||||
def test_alpn(self, tctx):
|
||||
def test_server_required(self, tctx):
|
||||
"""
|
||||
Here we test the scenario where a server connection is required (e.g. because of ALPN negotation)
|
||||
Here we test the scenario where a server connection is required (because SNI is missing)
|
||||
to establish TLS with the client.
|
||||
"""
|
||||
tssl_server = SSLTest(server_side=True, alpn=["foo", "bar"])
|
||||
|
||||
playbook, tssl_client = _test_tls_client_server(tctx, ["qux", "foo"])
|
||||
tssl_server = SSLTest(server_side=True)
|
||||
playbook, client_layer, tssl_client = _test_tls_client_server(tctx, sni=None)
|
||||
|
||||
# We should now get instructed to open a server connection.
|
||||
assert (
|
||||
playbook
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
)
|
||||
tctx.server.connected = True
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.OpenConnectionReply(-1, None)
|
||||
<< commands.SendData(tctx.server, data)
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl_client.out.read())
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> tutils.reply(None)
|
||||
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.server, data)
|
||||
)
|
||||
assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls
|
||||
assert playbook.layer.child_layer.tls[tctx.server]
|
||||
|
||||
# Establish TLS with the server...
|
||||
tssl_server.inc.write(data())
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl_server.obj.do_handshake()
|
||||
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
||||
<< commands.SendData(tctx.server, data)
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
||||
<< commands.SendData(tctx.server, data)
|
||||
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||
)
|
||||
tssl_server.inc.write(data())
|
||||
tssl_server.obj.do_handshake()
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
||||
<< commands.SendData(tctx.client, data)
|
||||
)
|
||||
|
||||
assert playbook.layer._handle_event == playbook.layer.state_process
|
||||
assert tctx.server.tls_established
|
||||
|
||||
# Server TLS is established, we can now reply to the client handshake...
|
||||
tssl_client.inc.write(data())
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl_client.obj.do_handshake()
|
||||
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, tssl_client.out.read())
|
||||
<< commands.SendData(tctx.client, data)
|
||||
playbook
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.client, data)
|
||||
)
|
||||
tssl_client.inc.write(data())
|
||||
tssl_client.obj.do_handshake()
|
||||
interact(playbook, tctx.client, tssl_client)
|
||||
|
||||
# Both handshakes completed!
|
||||
assert tctx.client.tls_established
|
||||
assert tctx.server.tls_established
|
||||
|
||||
assert tssl_client.obj.selected_alpn_protocol() == "foo"
|
||||
assert tssl_server.obj.selected_alpn_protocol() == "foo"
|
||||
_test_echo(playbook, tssl_server, tctx.server)
|
||||
_test_echo(playbook, tssl_client, tctx.client)
|
||||
|
@ -5,13 +5,13 @@ from test.mitmproxy.proxy2 import tutils
|
||||
class TestNextLayer:
|
||||
def test_simple(self, tctx):
|
||||
nl = layer.NextLayer(tctx)
|
||||
playbook = tutils.playbook(nl)
|
||||
playbook = tutils.playbook(nl, hooks=True)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< commands.Hook("next_layer", nl)
|
||||
>> events.HookReply(-1)
|
||||
>> tutils.reply()
|
||||
>> events.DataReceived(tctx.client, b"bar")
|
||||
<< commands.Hook("next_layer", nl)
|
||||
)
|
||||
@ -21,7 +21,7 @@ class TestNextLayer:
|
||||
nl.layer = tutils.EchoLayer(tctx)
|
||||
assert (
|
||||
playbook
|
||||
>> events.HookReply(-1)
|
||||
>> tutils.reply()
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
<< commands.SendData(tctx.client, b"bar")
|
||||
)
|
||||
@ -45,7 +45,7 @@ class TestNextLayer:
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> events.HookReply(-2)
|
||||
>> tutils.reply(to=-2)
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
<< commands.SendData(tctx.client, b"bar")
|
||||
)
|
||||
@ -63,7 +63,7 @@ class TestNextLayer:
|
||||
handle = nl.handle_event
|
||||
assert (
|
||||
playbook
|
||||
>> events.HookReply(-1)
|
||||
>> tutils.reply()
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
)
|
||||
sd, = handle(events.DataReceived(tctx.client, b"bar"))
|
||||
|
@ -38,7 +38,7 @@ class TLayer(Layer):
|
||||
|
||||
@pytest.fixture
|
||||
def tplaybook(tctx):
|
||||
return tutils.playbook(TLayer(tctx), [])
|
||||
return tutils.playbook(TLayer(tctx), expected=[])
|
||||
|
||||
|
||||
def test_simple(tplaybook):
|
||||
@ -158,7 +158,7 @@ def test_command_reply(tplaybook):
|
||||
tplaybook
|
||||
>> TEvent()
|
||||
<< TCommand()
|
||||
>> TCommandReply(-1, 42)
|
||||
>> tutils.reply(42)
|
||||
)
|
||||
assert tplaybook.actual[1] == tplaybook.actual[2].command
|
||||
|
||||
|
@ -2,6 +2,7 @@ import collections.abc
|
||||
import copy
|
||||
import difflib
|
||||
import itertools
|
||||
import sys
|
||||
import typing
|
||||
|
||||
from mitmproxy.proxy2 import commands, context
|
||||
@ -101,7 +102,7 @@ class playbook:
|
||||
def __init__(
|
||||
self,
|
||||
layer: Layer,
|
||||
hooks: bool = False,
|
||||
hooks: bool = True,
|
||||
logs: bool = False,
|
||||
expected: typing.Optional[TPlaybook] = None,
|
||||
):
|
||||
@ -196,13 +197,13 @@ class playbook:
|
||||
class reply(events.Event):
|
||||
args: typing.Tuple[typing.Any, ...]
|
||||
to: typing.Union[commands.Command, int]
|
||||
side_effect: typing.Callable[[commands.Command], typing.Any]
|
||||
side_effect: typing.Callable[[typing.Any], typing.Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
to: typing.Union[commands.Command, int] = -1,
|
||||
side_effect: typing.Callable[[commands.Command], typing.Any] = lambda cmd: None
|
||||
side_effect: typing.Callable[[typing.Any], None] = lambda x: None
|
||||
):
|
||||
"""Utility method to reply to the latest hook in playbooks."""
|
||||
self.args = args
|
||||
@ -226,6 +227,7 @@ class reply(events.Event):
|
||||
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
|
||||
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
|
||||
|
||||
assert isinstance(self.to, commands.Command)
|
||||
self.side_effect(self.to)
|
||||
reply_cls = command_reply_subclasses[type(self.to)]
|
||||
try:
|
||||
@ -272,14 +274,16 @@ def Placeholder() -> typing.Any:
|
||||
class EchoLayer(Layer):
|
||||
"""Echo layer that sends all data back to the client in lowercase."""
|
||||
|
||||
def _handle_event(self, event: events.Event):
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, events.DataReceived):
|
||||
yield commands.SendData(event.connection, event.data.lower())
|
||||
|
||||
|
||||
def next_layer(
|
||||
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]]
|
||||
) -> events.HookReply:
|
||||
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]],
|
||||
*args,
|
||||
**kwargs
|
||||
) -> reply:
|
||||
"""
|
||||
Helper function to simplify the syntax for next_layer events from this:
|
||||
|
||||
@ -294,21 +298,10 @@ def next_layer(
|
||||
|
||||
<< commands.Hook("next_layer", next_layer)
|
||||
>> tutils.next_layer(next_layer, tutils.EchoLayer)
|
||||
>> tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context)
|
||||
"""
|
||||
raise RuntimeError("Does tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) work?")
|
||||
if isinstance(layer, type):
|
||||
def make_layer(ctx: context.Context) -> Layer:
|
||||
return layer(ctx)
|
||||
else:
|
||||
make_layer = layer
|
||||
|
||||
def set_layer(playbook: playbook) -> None:
|
||||
last_command = playbook.actual[-1]
|
||||
assert isinstance(last_command, commands.Hook)
|
||||
assert isinstance(last_command.data, NextLayer)
|
||||
last_command.data.layer = make_layer(last_command.data.context)
|
||||
def set_layer(hook: commands.Hook) -> None:
|
||||
assert isinstance(hook.data, NextLayer)
|
||||
hook.data.layer = layer(hook.data.context)
|
||||
|
||||
reply = events.HookReply(-1)
|
||||
reply._playbook_eval = set_layer
|
||||
return reply
|
||||
return reply(*args, side_effect=set_layer, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user