mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 10:16:27 +00:00
[sans-io] tls layer++
This commit is contained in:
parent
0c04638d8d
commit
1c80dfe17f
138
mitmproxy/addons/tlsconfig.py
Normal file
138
mitmproxy/addons/tlsconfig.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from OpenSSL import SSL, crypto
|
||||||
|
|
||||||
|
from mitmproxy import certs, ctx, exceptions
|
||||||
|
from mitmproxy.net import tls as net_tls
|
||||||
|
from mitmproxy.options import CONF_BASENAME
|
||||||
|
from mitmproxy.proxy.protocol.tls import CIPHER_ID_NAME_MAP, DEFAULT_CLIENT_CIPHERS
|
||||||
|
from mitmproxy.proxy2 import context
|
||||||
|
from mitmproxy.proxy2.layers import tls
|
||||||
|
|
||||||
|
|
||||||
|
def alpn_select_callback(conn: SSL.Connection, options):
|
||||||
|
server_alpn = conn.get_app_data()["server_alpn"]
|
||||||
|
if server_alpn and server_alpn in options:
|
||||||
|
return server_alpn
|
||||||
|
for alpn in tls.HTTP_ALPNS:
|
||||||
|
if alpn in options:
|
||||||
|
return alpn
|
||||||
|
else:
|
||||||
|
# FIXME: pyOpenSSL requires that an ALPN is negotiated, we can't return SSL_TLSEXT_ERR_NOACK.
|
||||||
|
return options[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TlsConfig:
|
||||||
|
certstore: certs.CertStore
|
||||||
|
|
||||||
|
# TODO: We should re-use SSL.Context options here, if only for TLS session resumption.
|
||||||
|
# This may require patches to pyOpenSSL, as some functionality is only exposed on contexts.
|
||||||
|
|
||||||
|
def get_cert(self, context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]:
|
||||||
|
return self.certstore.get_cert(
|
||||||
|
context.client.sni, [context.client.sni]
|
||||||
|
)
|
||||||
|
|
||||||
|
def tls_start(self, tls_start: tls.TlsStart):
|
||||||
|
if tls_start.conn == tls_start.context.client:
|
||||||
|
self.create_client_proxy_ssl_conn(tls_start)
|
||||||
|
else:
|
||||||
|
self.create_proxy_server_ssl_conn(tls_start)
|
||||||
|
|
||||||
|
def create_client_proxy_ssl_conn(self, tls_start: tls.TlsStart) -> None:
|
||||||
|
tls_method, tls_options = net_tls.VERSION_CHOICES[ctx.options.ssl_version_client]
|
||||||
|
cert, key, chain_file = self.get_cert(tls_start.context)
|
||||||
|
if ctx.options.add_upstream_certs_to_client_chain:
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
extra_chain_certs = None
|
||||||
|
ssl_ctx = net_tls.create_server_context(
|
||||||
|
cert=cert,
|
||||||
|
key=key,
|
||||||
|
method=tls_method,
|
||||||
|
options=tls_options,
|
||||||
|
cipher_list=ctx.options.ciphers_client or DEFAULT_CLIENT_CIPHERS,
|
||||||
|
dhparams=self.certstore.dhparams,
|
||||||
|
chain_file=chain_file,
|
||||||
|
alpn_select_callback=alpn_select_callback,
|
||||||
|
extra_chain_certs=extra_chain_certs,
|
||||||
|
)
|
||||||
|
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||||
|
tls_start.ssl_conn.set_app_data({
|
||||||
|
"server_alpn": tls_start.context.server.alpn
|
||||||
|
})
|
||||||
|
|
||||||
|
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStart) -> None:
|
||||||
|
client = tls_start.context.client
|
||||||
|
server: context.Server = tls_start.conn
|
||||||
|
|
||||||
|
if server.sni is True:
|
||||||
|
server.sni = client.sni or server.address[0].encode()
|
||||||
|
|
||||||
|
if not server.alpn_offers:
|
||||||
|
if client.alpn:
|
||||||
|
server.alpn_offers = [client.alpn]
|
||||||
|
elif client.alpn_offers:
|
||||||
|
server.alpn_offers = client.alpn_offers
|
||||||
|
|
||||||
|
# We pass through the list of ciphers send by the client, because some HTTP/2 servers
|
||||||
|
# will select a non-HTTP/2 compatible cipher from our default list and then hang up
|
||||||
|
# because it's incompatible with h2.
|
||||||
|
if not server.cipher_list:
|
||||||
|
if ctx.options.ciphers_server:
|
||||||
|
server.cipher_list = ctx.options.ciphers_server.split(":")
|
||||||
|
elif client.cipher_list:
|
||||||
|
server.cipher_list = [
|
||||||
|
x for x in client.cipher_list
|
||||||
|
if x in CIPHER_ID_NAME_MAP
|
||||||
|
]
|
||||||
|
|
||||||
|
args = net_tls.client_arguments_from_options(ctx.options)
|
||||||
|
|
||||||
|
client_certs = args.pop("client_certs")
|
||||||
|
client_cert: Optional[str] = None
|
||||||
|
if client_certs:
|
||||||
|
client_certs = os.path.expanduser(client_certs)
|
||||||
|
if os.path.isfile(client_certs):
|
||||||
|
client_cert = client_certs
|
||||||
|
else:
|
||||||
|
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
|
||||||
|
path = os.path.join(client_certs, f"{server_name}.pem")
|
||||||
|
if os.path.exists(path):
|
||||||
|
client_cert = path
|
||||||
|
|
||||||
|
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
|
||||||
|
ssl_ctx = net_tls.create_client_context(
|
||||||
|
cert=client_cert,
|
||||||
|
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.
|
||||||
|
alpn_protos=server.alpn_offers,
|
||||||
|
**args
|
||||||
|
)
|
||||||
|
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||||
|
|
||||||
|
def configure(self, updated):
|
||||||
|
if not any(x in updated for x in ["confdir", "certs"]):
|
||||||
|
return
|
||||||
|
|
||||||
|
certstore_path = os.path.expanduser(ctx.options.confdir)
|
||||||
|
if not os.path.exists(os.path.dirname(certstore_path)):
|
||||||
|
raise exceptions.OptionsError(
|
||||||
|
f"Certificate Authority parent directory does not exist: {os.path.dirname(certstore_path)}")
|
||||||
|
self.certstore = certs.CertStore.from_store(
|
||||||
|
path=certstore_path,
|
||||||
|
basename=CONF_BASENAME,
|
||||||
|
key_size=ctx.options.key_size
|
||||||
|
)
|
||||||
|
for certspec in ctx.options.certs:
|
||||||
|
parts = certspec.split("=", 1)
|
||||||
|
if len(parts) == 1:
|
||||||
|
parts = ["*", parts[0]]
|
||||||
|
|
||||||
|
cert = os.path.expanduser(parts[1])
|
||||||
|
if not os.path.exists(cert):
|
||||||
|
raise exceptions.OptionsError(f"Certificate file does not exist: {cert}")
|
||||||
|
try:
|
||||||
|
self.certstore.add_cert_file(parts[0], cert)
|
||||||
|
except crypto.Error as e:
|
||||||
|
raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e
|
@ -21,8 +21,12 @@ class Connection:
|
|||||||
tls_established: bool = False
|
tls_established: bool = False
|
||||||
alpn: Optional[bytes] = None
|
alpn: Optional[bytes] = None
|
||||||
alpn_offers: Sequence[bytes] = ()
|
alpn_offers: Sequence[bytes] = ()
|
||||||
|
cipher_list: Sequence[bytes] = ()
|
||||||
|
tls_version: Optional[str] = None
|
||||||
sni: Union[bytes, bool, None]
|
sni: Union[bytes, bool, None]
|
||||||
|
|
||||||
|
timestamp_tls_setup: Optional[float] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connected(self):
|
def connected(self):
|
||||||
return self.state is ConnectionState.OPEN
|
return self.state is ConnectionState.OPEN
|
||||||
|
@ -67,7 +67,7 @@ class Layer:
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
"""Handle a proxy server event"""
|
"""Handle a proxy server event"""
|
||||||
yield from ()
|
yield from () # pragma: no cover
|
||||||
|
|
||||||
def handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
def handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
if self._paused:
|
if self._paused:
|
||||||
@ -95,10 +95,7 @@ class Layer:
|
|||||||
processing any other commands.
|
processing any other commands.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if isinstance(send, Exception):
|
command = command_generator.send(send)
|
||||||
command = command_generator.throw(type(send), send)
|
|
||||||
else:
|
|
||||||
command = command_generator.send(send)
|
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from mitmproxy.proxy.protocol.http import HTTPMode
|
|||||||
from mitmproxy.proxy2 import commands, events
|
from mitmproxy.proxy2 import commands, events
|
||||||
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
||||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
from mitmproxy.proxy2.layer import Layer, NextLayer
|
||||||
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply
|
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
|
|
||||||
@ -676,6 +676,9 @@ class HTTPLayer(Layer):
|
|||||||
|
|
||||||
def make_http_connection(self, connection: Server) -> None:
|
def make_http_connection(self, connection: Server) -> None:
|
||||||
if connection.tls and not connection.tls_established:
|
if connection.tls and not connection.tls_established:
|
||||||
|
connection.alpn_offers = list(HTTP_ALPNS)
|
||||||
|
if not self.context.options.http2:
|
||||||
|
connection.alpn_offers.remove(b"h2")
|
||||||
new_command = EstablishServerTLS(connection)
|
new_command = EstablishServerTLS(connection)
|
||||||
new_command.blocking = object()
|
new_command.blocking = object()
|
||||||
yield new_command
|
yield new_command
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
import os
|
|
||||||
import struct
|
import struct
|
||||||
|
import time
|
||||||
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from mitmproxy.certs import CertStore
|
from mitmproxy.net import tls as net_tls
|
||||||
from mitmproxy.net.tls import ClientHello
|
|
||||||
from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS
|
|
||||||
from mitmproxy.proxy2 import commands, events, layer
|
from mitmproxy.proxy2 import commands, events, layer
|
||||||
from mitmproxy.proxy2 import context
|
from mitmproxy.proxy2 import context
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
@ -69,7 +67,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_client_hello(data: bytes) -> Optional[ClientHello]:
|
def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
||||||
"""
|
"""
|
||||||
Check if the supplied bytes contain a full ClientHello message,
|
Check if the supplied bytes contain a full ClientHello message,
|
||||||
and if so, parse it.
|
and if so, parse it.
|
||||||
@ -84,10 +82,13 @@ def parse_client_hello(data: bytes) -> Optional[ClientHello]:
|
|||||||
# Check if ClientHello is complete
|
# Check if ClientHello is complete
|
||||||
client_hello = get_client_hello(data)
|
client_hello = get_client_hello(data)
|
||||||
if client_hello:
|
if client_hello:
|
||||||
return ClientHello(client_hello[4:])
|
return net_tls.ClientHello(client_hello[4:])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
|
||||||
|
|
||||||
|
|
||||||
class EstablishServerTLS(commands.ConnectionCommand):
|
class EstablishServerTLS(commands.ConnectionCommand):
|
||||||
connection: context.Server
|
connection: context.Server
|
||||||
blocking = True
|
blocking = True
|
||||||
@ -99,9 +100,17 @@ class EstablishServerTLSReply(events.CommandReply):
|
|||||||
"""error message"""
|
"""error message"""
|
||||||
|
|
||||||
|
|
||||||
|
class TlsStart:
|
||||||
|
def __init__(self, conn: context.Connection, context: context.Context) -> None:
|
||||||
|
self.conn = conn
|
||||||
|
self.context = context
|
||||||
|
self.ssl_conn = None
|
||||||
|
|
||||||
|
|
||||||
class _TLSLayer(layer.Layer):
|
class _TLSLayer(layer.Layer):
|
||||||
tls: Dict[context.Connection, SSL.Connection]
|
tls: Dict[context.Connection, SSL.Connection]
|
||||||
child_layer: layer.Layer
|
child_layer: layer.Layer
|
||||||
|
ssl_context: Optional[SSL.Context] = None
|
||||||
|
|
||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context):
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
@ -140,15 +149,18 @@ class _TLSLayer(layer.Layer):
|
|||||||
except SSL.WantReadError:
|
except SSL.WantReadError:
|
||||||
yield from self.tls_interact(conn)
|
yield from self.tls_interact(conn)
|
||||||
return False, None
|
return False, None
|
||||||
except SSL.ZeroReturnError as e:
|
except SSL.Error as e:
|
||||||
return False, repr(e)
|
return False, repr(e)
|
||||||
else:
|
else:
|
||||||
conn.tls_established = True
|
conn.tls_established = True
|
||||||
|
conn.sni = self.tls[conn].get_servername()
|
||||||
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
||||||
|
conn.cipher_list = self.tls[conn].get_cipher_list()
|
||||||
|
conn.tls_version = self.tls[conn].get_protocol_version_name()
|
||||||
|
conn.timestamp_tls_setup = time.time()
|
||||||
yield commands.Log(f"TLS established: {conn}")
|
yield commands.Log(f"TLS established: {conn}")
|
||||||
yield from self.receive(conn, b"")
|
yield from self.receive(conn, b"")
|
||||||
# TODO: Set all other connection attributes here
|
# TODO: Set all other connection attributes here
|
||||||
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
|
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def receive(self, conn: context.Connection, data: bytes):
|
def receive(self, conn: context.Connection, data: bytes):
|
||||||
@ -213,8 +225,8 @@ class ServerTLSLayer(_TLSLayer):
|
|||||||
self.command_to_reply_to = {}
|
self.command_to_reply_to = {}
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
self.child_layer = layer.NextLayer(self.context)
|
||||||
|
|
||||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
|
def negotiate(self, conn: context.Connection, data: bytes) \
|
||||||
commands.Command, Any, Tuple[bool, Optional[str]]]:
|
-> Generator[commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||||
done, err = yield from super().negotiate(conn, data)
|
done, err = yield from super().negotiate(conn, data)
|
||||||
if done or err:
|
if done or err:
|
||||||
cmd = self.command_to_reply_to.pop(conn)
|
cmd = self.command_to_reply_to.pop(conn)
|
||||||
@ -232,19 +244,11 @@ class ServerTLSLayer(_TLSLayer):
|
|||||||
def start_server_tls(self, conn: context.Server):
|
def start_server_tls(self, conn: context.Server):
|
||||||
assert conn not in self.tls
|
assert conn not in self.tls
|
||||||
assert conn.connected
|
assert conn.connected
|
||||||
|
conn.tls = True
|
||||||
|
|
||||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
tls_start = TlsStart(conn, self.context)
|
||||||
if conn.alpn_offers:
|
yield commands.Hook("tls_start", tls_start)
|
||||||
ssl_context.set_alpn_protos(conn.alpn_offers)
|
self.tls[conn] = tls_start.ssl_conn
|
||||||
self.tls[conn] = SSL.Connection(ssl_context)
|
|
||||||
|
|
||||||
if conn.sni:
|
|
||||||
if conn.sni is True:
|
|
||||||
if self.context.client.sni:
|
|
||||||
conn.sni = self.context.client.sni
|
|
||||||
else:
|
|
||||||
conn.sni = conn.address[0].encode()
|
|
||||||
self.tls[conn].set_tlsext_host_name(conn.sni)
|
|
||||||
self.tls[conn].set_connect_state()
|
self.tls[conn].set_connect_state()
|
||||||
|
|
||||||
yield from self.negotiate(conn, b"")
|
yield from self.negotiate(conn, b"")
|
||||||
@ -274,6 +278,7 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.recv_buffer = bytearray()
|
self.recv_buffer = bytearray()
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
self.child_layer = layer.NextLayer(self.context)
|
||||||
|
self._handle_event = self.state_start
|
||||||
|
|
||||||
@expect(events.Start)
|
@expect(events.Start)
|
||||||
def state_start(self, _) -> commands.TCommandGenerator:
|
def state_start(self, _) -> commands.TCommandGenerator:
|
||||||
@ -281,9 +286,6 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
self._handle_event = self.state_wait_for_clienthello
|
self._handle_event = self.state_wait_for_clienthello
|
||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
_handle_event = state_start
|
|
||||||
|
|
||||||
@expect(events.DataReceived, events.ConnectionClosed)
|
|
||||||
def state_wait_for_clienthello(self, event: events.Event):
|
def state_wait_for_clienthello(self, event: events.Event):
|
||||||
client = self.context.client
|
client = self.context.client
|
||||||
if isinstance(event, events.DataReceived) and event.connection == client:
|
if isinstance(event, events.DataReceived) and event.connection == client:
|
||||||
@ -296,8 +298,7 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
if client_hello:
|
if client_hello:
|
||||||
yield commands.Log(f"Client Hello: {client_hello}")
|
yield commands.Log(f"Client Hello: {client_hello}")
|
||||||
|
|
||||||
# TODO: Don't do double conversion
|
client.sni = client_hello.sni
|
||||||
client.sni = client_hello.sni.encode("idna")
|
|
||||||
client.alpn_offers = client_hello.alpn_protocols
|
client.alpn_offers = client_hello.alpn_protocols
|
||||||
|
|
||||||
client_tls_requires_server_connection = (
|
client_tls_requires_server_connection = (
|
||||||
@ -322,8 +323,10 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
|
|
||||||
# In any case, we now have enough information to start server TLS if needed.
|
# In any case, we now have enough information to start server TLS if needed.
|
||||||
yield from self.event_to_child(events.Start())
|
yield from self.event_to_child(events.Start())
|
||||||
|
elif isinstance(event, events.ConnectionClosed) and event.connection == client:
|
||||||
|
self.recv_buffer.clear()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(event) # TODO
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
def start_server_tls(self):
|
def start_server_tls(self):
|
||||||
"""
|
"""
|
||||||
@ -339,11 +342,6 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
|
|
||||||
server.alpn_offers = [
|
|
||||||
x for x in self.context.client.alpn_offers
|
|
||||||
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
|
|
||||||
]
|
|
||||||
|
|
||||||
err = yield EstablishServerTLS(server)
|
err = yield EstablishServerTLS(server)
|
||||||
if err:
|
if err:
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
@ -352,36 +350,10 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
return err
|
return err
|
||||||
|
|
||||||
def start_client_tls(self) -> commands.TCommandGenerator:
|
def start_client_tls(self) -> commands.TCommandGenerator:
|
||||||
# FIXME: Do this properly. Also adjust error message in negotiate()
|
|
||||||
client = self.context.client
|
client = self.context.client
|
||||||
server = self.context.server
|
tls_start = TlsStart(client, self.context)
|
||||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
yield commands.Hook("tls_start", tls_start)
|
||||||
cert, privkey, cert_chain = CertStore.from_store(
|
self.tls[client] = tls_start.ssl_conn
|
||||||
os.path.expanduser("~/.mitmproxy"), "mitmproxy",
|
|
||||||
self.context.options.key_size
|
|
||||||
).get_cert(client.sni, (client.sni,))
|
|
||||||
context.use_privatekey(privkey)
|
|
||||||
context.use_certificate(cert.x509)
|
|
||||||
context.set_cipher_list(DEFAULT_CLIENT_CIPHERS)
|
|
||||||
|
|
||||||
def alpn_select_callback(conn_, options):
|
|
||||||
if server.alpn in options:
|
|
||||||
return server.alpn
|
|
||||||
elif b"h2" in options:
|
|
||||||
return b"h2"
|
|
||||||
elif b"http/1.1" in options:
|
|
||||||
return b"http/1.1"
|
|
||||||
elif b"http/1.0" in options:
|
|
||||||
return b"http/1.0"
|
|
||||||
elif b"http/0.9" in options:
|
|
||||||
return b"http/0.9"
|
|
||||||
else:
|
|
||||||
# FIXME: We MUST return something here. At this point we are at loss.
|
|
||||||
return options[0]
|
|
||||||
|
|
||||||
context.set_alpn_select_callback(alpn_select_callback)
|
|
||||||
|
|
||||||
self.tls[client] = SSL.Connection(context)
|
|
||||||
self.tls[client].set_accept_state()
|
self.tls[client].set_accept_state()
|
||||||
|
|
||||||
yield from self.negotiate(client, bytes(self.recv_buffer))
|
yield from self.negotiate(client, bytes(self.recv_buffer))
|
||||||
@ -390,11 +362,16 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||||
done, err = yield from super().negotiate(conn, data)
|
done, err = yield from super().negotiate(conn, data)
|
||||||
if err:
|
if err:
|
||||||
|
if self.context.client.sni:
|
||||||
|
# TODO: Also use other sources than SNI
|
||||||
|
dest = " for " + self.context.client.sni.decode("idna")
|
||||||
|
else:
|
||||||
|
dest = ""
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
f"Client TLS Handshake failed. "
|
f"Client TLS Handshake failed. "
|
||||||
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
|
f"The client may not trust the proxy's certificate{dest} ({err}).",
|
||||||
level="warn"
|
level="warn"
|
||||||
# TODO: Also use other sources than SNI
|
|
||||||
)
|
)
|
||||||
yield commands.CloseConnection(self.context.client)
|
yield commands.CloseConnection(self.context.client)
|
||||||
return done
|
return done
|
||||||
|
@ -24,7 +24,7 @@ def test_open_connection(tctx):
|
|||||||
def test_open_connection_err(tctx):
|
def test_open_connection_err(tctx):
|
||||||
f = Placeholder()
|
f = Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook(TCPLayer(tctx), hooks=True)
|
playbook(TCPLayer(tctx))
|
||||||
<< Hook("tcp_start", f)
|
<< Hook("tcp_start", f)
|
||||||
>> reply()
|
>> reply()
|
||||||
<< OpenConnection(tctx.server)
|
<< OpenConnection(tctx.server)
|
||||||
@ -40,7 +40,7 @@ def test_simple(tctx):
|
|||||||
f = Placeholder()
|
f = Placeholder()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
playbook(TCPLayer(tctx), hooks=True)
|
playbook(TCPLayer(tctx))
|
||||||
<< Hook("tcp_start", f)
|
<< Hook("tcp_start", f)
|
||||||
>> reply()
|
>> reply()
|
||||||
<< OpenConnection(tctx.server)
|
<< OpenConnection(tctx.server)
|
||||||
@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx):
|
|||||||
will still be forwarded.
|
will still be forwarded.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert (
|
||||||
playbook(TCPLayer(tctx))
|
playbook(TCPLayer(tctx), hooks=False)
|
||||||
<< OpenConnection(tctx.server)
|
<< OpenConnection(tctx.server)
|
||||||
>> DataReceived(tctx.client, b"hello!")
|
>> DataReceived(tctx.client, b"hello!")
|
||||||
>> reply(None, to=-2)
|
>> reply(None, to=-2)
|
||||||
|
@ -3,11 +3,15 @@ import ssl
|
|||||||
import typing
|
import typing
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from mitmproxy.proxy2 import context, events, commands
|
from mitmproxy.proxy2 import commands, context, events
|
||||||
from mitmproxy.proxy2.layers import tls
|
from mitmproxy.proxy2.layers import tls
|
||||||
|
from mitmproxy.utils import data
|
||||||
from test.mitmproxy.proxy2 import tutils
|
from test.mitmproxy.proxy2 import tutils
|
||||||
|
|
||||||
|
tlsdata = data.Data(__name__)
|
||||||
|
|
||||||
|
|
||||||
def test_is_tls_handshake_record():
|
def test_is_tls_handshake_record():
|
||||||
assert tls.is_tls_handshake_record(bytes.fromhex("160300"))
|
assert tls.is_tls_handshake_record(bytes.fromhex("160300"))
|
||||||
@ -33,11 +37,11 @@ def test_record_contents():
|
|||||||
|
|
||||||
|
|
||||||
def test_record_contents_err():
|
def test_record_contents_err():
|
||||||
with pytest.raises(ValueError, msg="Expected TLS record"):
|
with pytest.raises(ValueError, match="Expected TLS record"):
|
||||||
next(tls.handshake_record_contents(b"GET /error"))
|
next(tls.handshake_record_contents(b"GET /error"))
|
||||||
|
|
||||||
empty_record = bytes.fromhex("1603010000")
|
empty_record = bytes.fromhex("1603010000")
|
||||||
with pytest.raises(ValueError, msg="Record must not be empty"):
|
with pytest.raises(ValueError, match="Record must not be empty"):
|
||||||
next(tls.handshake_record_contents(empty_record))
|
next(tls.handshake_record_contents(empty_record))
|
||||||
|
|
||||||
|
|
||||||
@ -53,8 +57,8 @@ def test_get_client_hello():
|
|||||||
assert tls.get_client_hello(single_record) == client_hello_no_extensions
|
assert tls.get_client_hello(single_record) == client_hello_no_extensions
|
||||||
|
|
||||||
split_over_two_records = (
|
split_over_two_records = (
|
||||||
bytes.fromhex("1603010020") + client_hello_no_extensions[:32] +
|
bytes.fromhex("1603010020") + client_hello_no_extensions[:32] +
|
||||||
bytes.fromhex("1603010045") + client_hello_no_extensions[32:]
|
bytes.fromhex("1603010045") + client_hello_no_extensions[32:]
|
||||||
)
|
)
|
||||||
assert tls.get_client_hello(split_over_two_records) == client_hello_no_extensions
|
assert tls.get_client_hello(split_over_two_records) == client_hello_no_extensions
|
||||||
|
|
||||||
@ -65,7 +69,8 @@ def test_get_client_hello():
|
|||||||
class SSLTest:
|
class SSLTest:
|
||||||
"""Helper container for Python's builtin SSL object."""
|
"""Helper container for Python's builtin SSL object."""
|
||||||
|
|
||||||
def __init__(self, server_side=False, alpn=None):
|
def __init__(self, server_side: bool = False, alpn: typing.List[bytes] = None,
|
||||||
|
sni: typing.Optional[bytes] = b"example.com"):
|
||||||
self.inc = ssl.MemoryBIO()
|
self.inc = ssl.MemoryBIO()
|
||||||
self.out = ssl.MemoryBIO()
|
self.out = ssl.MemoryBIO()
|
||||||
self.ctx = ssl.SSLContext()
|
self.ctx = ssl.SSLContext()
|
||||||
@ -77,83 +82,78 @@ class SSLTest:
|
|||||||
self.obj = self.ctx.wrap_bio(
|
self.obj = self.ctx.wrap_bio(
|
||||||
self.inc,
|
self.inc,
|
||||||
self.out,
|
self.out,
|
||||||
server_hostname=None if server_side else "example.com",
|
server_hostname=None if server_side else sni,
|
||||||
server_side=server_side,
|
server_side=server_side,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_tls_client_server(
|
def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
|
||||||
tctx: context.Context,
|
|
||||||
alpn: typing.Optional[str]
|
|
||||||
) -> typing.Tuple[tutils.playbook[tls.ClientTLSLayer], SSLTest]:
|
|
||||||
layer = tls.ClientTLSLayer(tctx)
|
|
||||||
playbook = tutils.playbook(layer)
|
|
||||||
tctx.server.tls = True
|
|
||||||
tctx.server.address = ("example.com", 443)
|
|
||||||
tssl_client = SSLTest(alpn=alpn)
|
|
||||||
|
|
||||||
# Handshake
|
|
||||||
assert (
|
|
||||||
playbook
|
|
||||||
<< None
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ssl.SSLWantReadError):
|
|
||||||
tssl_client.obj.do_handshake()
|
|
||||||
client_hello = tssl_client.out.read()
|
|
||||||
assert (
|
|
||||||
playbook
|
|
||||||
>> events.DataReceived(tctx.client, client_hello[:42])
|
|
||||||
<< None
|
|
||||||
)
|
|
||||||
# Still waiting...
|
|
||||||
# Finish sending ClientHello
|
|
||||||
playbook >> events.DataReceived(tctx.client, client_hello[42:])
|
|
||||||
return playbook, tssl_client
|
|
||||||
|
|
||||||
|
|
||||||
def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
|
|
||||||
tssl.obj.write(b"Hello World")
|
tssl.obj.write(b"Hello World")
|
||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(conn, tssl.out.read())
|
>> events.DataReceived(conn, tssl.out.read())
|
||||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
<< commands.SendData(conn, data)
|
||||||
>> tutils.next_layer(tutils.EchoLayer)
|
|
||||||
<< commands.SendData(conn, data)
|
|
||||||
)
|
)
|
||||||
tssl.inc.write(data())
|
tssl.inc.write(data())
|
||||||
assert tssl.obj.read() == b"hello world"
|
assert tssl.obj.read() == b"hello world"
|
||||||
|
|
||||||
|
|
||||||
|
class TlsEchoLayer(tutils.EchoLayer):
|
||||||
|
err: typing.Optional[str] = None
|
||||||
|
|
||||||
|
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
|
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.err = yield tls.EstablishServerTLS(self.context.server)
|
||||||
|
else:
|
||||||
|
yield from super()._handle_event(event)
|
||||||
|
|
||||||
|
|
||||||
|
def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest):
|
||||||
|
data = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.DataReceived(conn, tssl.out.read())
|
||||||
|
<< commands.SendData(conn, data)
|
||||||
|
)
|
||||||
|
tssl.inc.write(data())
|
||||||
|
|
||||||
|
|
||||||
|
def reply_tls_start(*args, **kwargs) -> tutils.reply:
|
||||||
|
"""
|
||||||
|
Helper function to simplify the syntax for tls_start hooks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def make_conn(hook: commands.Hook) -> None:
|
||||||
|
tls_start = hook.data
|
||||||
|
assert isinstance(tls_start, tls.TlsStart)
|
||||||
|
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
|
if tls_start.conn == tls_start.context.client:
|
||||||
|
ssl_context.use_privatekey_file(
|
||||||
|
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.key")
|
||||||
|
)
|
||||||
|
ssl_context.use_certificate_chain_file(
|
||||||
|
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.crt")
|
||||||
|
)
|
||||||
|
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||||
|
|
||||||
|
return tutils.reply(*args, side_effect=make_conn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TestServerTLS:
|
class TestServerTLS:
|
||||||
def test_no_tls(self, tctx: context.Context):
|
def test_no_tls(self, tctx: context.Context):
|
||||||
"""Test TLS layer without TLS"""
|
"""Test TLS layer without TLS"""
|
||||||
layer = tls.ServerTLSLayer(tctx)
|
layer = tls.ServerTLSLayer(tctx)
|
||||||
playbook = tutils.playbook(layer)
|
layer.child_layer = TlsEchoLayer(tctx)
|
||||||
|
|
||||||
# Handshake
|
# Handshake
|
||||||
assert (
|
assert (
|
||||||
playbook
|
tutils.playbook(layer)
|
||||||
>> events.DataReceived(tctx.client, b"Hello World")
|
>> events.DataReceived(tctx.client, b"Hello World")
|
||||||
<< commands.Hook("next_layer", tutils.Placeholder())
|
<< commands.SendData(tctx.client, b"hello world")
|
||||||
>> tutils.next_layer(tutils.EchoLayer)
|
>> events.DataReceived(tctx.server, b"Foo")
|
||||||
<< commands.SendData(tctx.client, b"hello world")
|
<< commands.SendData(tctx.server, b"foo")
|
||||||
)
|
|
||||||
|
|
||||||
def test_no_connection(self, tctx):
|
|
||||||
"""
|
|
||||||
The server TLS layer is initiated, but there is no active connection yet, so nothing
|
|
||||||
should be done.
|
|
||||||
"""
|
|
||||||
layer = tls.ServerTLSLayer(tctx)
|
|
||||||
playbook = tutils.playbook(layer)
|
|
||||||
tctx.server.tls = True
|
|
||||||
|
|
||||||
# We did not have a server connection before, so let's do nothing.
|
|
||||||
assert (
|
|
||||||
playbook
|
|
||||||
<< None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_simple(self, tctx):
|
def test_simple(self, tctx):
|
||||||
@ -161,166 +161,171 @@ class TestServerTLS:
|
|||||||
playbook = tutils.playbook(layer)
|
playbook = tutils.playbook(layer)
|
||||||
tctx.server.connected = True
|
tctx.server.connected = True
|
||||||
tctx.server.address = ("example.com", 443)
|
tctx.server.address = ("example.com", 443)
|
||||||
tctx.server.tls = True
|
|
||||||
|
|
||||||
tssl = SSLTest(server_side=True)
|
tssl = SSLTest(server_side=True)
|
||||||
|
|
||||||
# send ClientHello
|
# send ClientHello
|
||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
<< commands.SendData(tctx.server, data)
|
>> events.DataReceived(tctx.client, b"establish-server-tls")
|
||||||
|
<< commands.Hook("next_layer", tutils.Placeholder())
|
||||||
|
>> tutils.next_layer(TlsEchoLayer)
|
||||||
|
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||||
|
>> reply_tls_start()
|
||||||
|
<< commands.SendData(tctx.server, data)
|
||||||
)
|
)
|
||||||
|
|
||||||
# receive ServerHello, finish client handshake
|
# receive ServerHello, finish client handshake
|
||||||
tssl.inc.write(data())
|
tssl.inc.write(data())
|
||||||
with pytest.raises(ssl.SSLWantReadError):
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
tssl.obj.do_handshake()
|
tssl.obj.do_handshake()
|
||||||
data = tutils.Placeholder()
|
interact(playbook, tctx.server, tssl)
|
||||||
assert (
|
|
||||||
playbook
|
|
||||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
|
||||||
<< commands.SendData(tctx.server, data)
|
|
||||||
)
|
|
||||||
tssl.inc.write(data())
|
|
||||||
|
|
||||||
# finish server handshake
|
# finish server handshake
|
||||||
tssl.obj.do_handshake()
|
tssl.obj.do_handshake()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||||
<< None
|
<< None
|
||||||
)
|
)
|
||||||
|
|
||||||
assert tctx.server.tls_established
|
assert tctx.server.tls_established
|
||||||
assert tctx.server.sni == b"example.com"
|
|
||||||
|
|
||||||
# Echo
|
# Echo
|
||||||
echo(playbook, tssl, tctx.server)
|
_test_echo(playbook, tssl, tctx.server)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]:
|
||||||
|
# This is a bit contrived as the client layer expects a server layer as parent.
|
||||||
|
# We also set child layers manually to avoid NextLayer noise.
|
||||||
|
server_layer = tls.ServerTLSLayer(tctx)
|
||||||
|
client_layer = tls.ClientTLSLayer(tctx)
|
||||||
|
server_layer.child_layer = client_layer
|
||||||
|
client_layer.child_layer = TlsEchoLayer(tctx)
|
||||||
|
playbook = tutils.playbook(server_layer)
|
||||||
|
return playbook, client_layer
|
||||||
|
|
||||||
|
|
||||||
|
def _test_tls_client_server(
|
||||||
|
tctx: context.Context,
|
||||||
|
sni: typing.Optional[bytes]
|
||||||
|
) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]:
|
||||||
|
playbook, client_layer = _make_client_tls_layer(tctx)
|
||||||
|
tctx.server.tls = True
|
||||||
|
tctx.server.address = ("example.com", 443)
|
||||||
|
tssl_client = SSLTest(sni=sni)
|
||||||
|
|
||||||
|
# Send ClientHello
|
||||||
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
|
tssl_client.obj.do_handshake()
|
||||||
|
|
||||||
|
return playbook, client_layer, tssl_client
|
||||||
|
|
||||||
|
|
||||||
class TestClientTLS:
|
class TestClientTLS:
|
||||||
def test_simple(self, tctx: context.Context):
|
def test_client_only(self, tctx: context.Context):
|
||||||
"""Test TLS with client only"""
|
"""Test TLS with client only"""
|
||||||
layer = tls.ClientTLSLayer(tctx)
|
playbook, client_layer = _make_client_tls_layer(tctx)
|
||||||
playbook = tutils.playbook(layer)
|
|
||||||
tssl = SSLTest()
|
tssl = SSLTest()
|
||||||
|
assert not tctx.client.tls_established
|
||||||
|
|
||||||
# Handshake
|
# Start Handshake, send ClientHello and ServerHello
|
||||||
assert playbook
|
|
||||||
assert layer._handle_event == layer.state_wait_for_clienthello
|
|
||||||
|
|
||||||
def interact():
|
|
||||||
data = tutils.Placeholder()
|
|
||||||
assert (
|
|
||||||
playbook
|
|
||||||
>> events.DataReceived(tctx.client, tssl.out.read())
|
|
||||||
<< commands.SendData(tctx.client, data)
|
|
||||||
)
|
|
||||||
tssl.inc.write(data())
|
|
||||||
try:
|
|
||||||
tssl.obj.do_handshake()
|
|
||||||
except ssl.SSLWantReadError:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# receive ClientHello, send ServerHello
|
|
||||||
with pytest.raises(ssl.SSLWantReadError):
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
tssl.obj.do_handshake()
|
tssl.obj.do_handshake()
|
||||||
assert not interact()
|
data = tutils.Placeholder()
|
||||||
# Finish Handshake
|
assert (
|
||||||
assert interact()
|
playbook
|
||||||
|
>> events.DataReceived(tctx.client, tssl.out.read())
|
||||||
|
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||||
|
>> reply_tls_start()
|
||||||
|
<< commands.SendData(tctx.client, data)
|
||||||
|
)
|
||||||
|
tssl.inc.write(data())
|
||||||
tssl.obj.do_handshake()
|
tssl.obj.do_handshake()
|
||||||
|
# Finish Handshake
|
||||||
|
interact(playbook, tctx.client, tssl)
|
||||||
|
|
||||||
assert layer._handle_event == layer.state_process
|
assert tssl.obj.getpeercert(True)
|
||||||
|
assert tctx.client.tls_established
|
||||||
|
|
||||||
# Echo
|
# Echo
|
||||||
echo(playbook, tssl, tctx.client)
|
_test_echo(playbook, tssl, tctx.client)
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.server, b"Hello")
|
>> events.DataReceived(tctx.server, b"Plaintext")
|
||||||
<< commands.SendData(tctx.server, b"hello")
|
<< commands.SendData(tctx.server, b"plaintext")
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_server_conn_required(self, tctx):
|
def test_server_not_required(self, tctx):
|
||||||
"""
|
"""
|
||||||
Here we test the scenario where a server connection is _not_ required
|
Here we test the scenario where a server connection is _not_ required
|
||||||
to establish TLS with the client. After determining this when parsing the ClientHello,
|
to establish TLS with the client. After determining this when parsing the ClientHello,
|
||||||
we only establish a connection with the client. The server connection may ultimately
|
we only establish a connection with the client. The server connection may ultimately
|
||||||
be established when OpenConnection is called.
|
be established when OpenConnection is called.
|
||||||
"""
|
"""
|
||||||
playbook, _ = _test_tls_client_server(tctx, None)
|
playbook, client_layer, tssl = _test_tls_client_server(tctx, sni=b"example.com")
|
||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
<< commands.SendData(tctx.client, data)
|
>> events.DataReceived(tctx.client, tssl.out.read())
|
||||||
|
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||||
|
>> reply_tls_start()
|
||||||
|
<< commands.SendData(tctx.client, data)
|
||||||
)
|
)
|
||||||
assert data()
|
tssl.inc.write(data())
|
||||||
assert playbook.layer._handle_event == playbook.layer.state_process
|
tssl.obj.do_handshake()
|
||||||
|
interact(playbook, tctx.client, tssl)
|
||||||
|
assert tctx.client.tls_established
|
||||||
|
|
||||||
def test_alpn(self, tctx):
|
def test_server_required(self, tctx):
|
||||||
"""
|
"""
|
||||||
Here we test the scenario where a server connection is required (e.g. because of ALPN negotation)
|
Here we test the scenario where a server connection is required (because SNI is missing)
|
||||||
to establish TLS with the client.
|
to establish TLS with the client.
|
||||||
"""
|
"""
|
||||||
tssl_server = SSLTest(server_side=True, alpn=["foo", "bar"])
|
tssl_server = SSLTest(server_side=True)
|
||||||
|
playbook, client_layer, tssl_client = _test_tls_client_server(tctx, sni=None)
|
||||||
playbook, tssl_client = _test_tls_client_server(tctx, ["qux", "foo"])
|
|
||||||
|
|
||||||
# We should now get instructed to open a server connection.
|
# We should now get instructed to open a server connection.
|
||||||
assert (
|
|
||||||
playbook
|
|
||||||
<< commands.OpenConnection(tctx.server)
|
|
||||||
)
|
|
||||||
tctx.server.connected = True
|
|
||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.OpenConnectionReply(-1, None)
|
>> events.DataReceived(tctx.client, tssl_client.out.read())
|
||||||
<< commands.SendData(tctx.server, data)
|
<< commands.OpenConnection(tctx.server)
|
||||||
|
>> tutils.reply(None)
|
||||||
|
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||||
|
>> reply_tls_start()
|
||||||
|
<< commands.SendData(tctx.server, data)
|
||||||
)
|
)
|
||||||
assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls
|
|
||||||
assert playbook.layer.child_layer.tls[tctx.server]
|
|
||||||
|
|
||||||
# Establish TLS with the server...
|
# Establish TLS with the server...
|
||||||
tssl_server.inc.write(data())
|
tssl_server.inc.write(data())
|
||||||
with pytest.raises(ssl.SSLWantReadError):
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
tssl_server.obj.do_handshake()
|
tssl_server.obj.do_handshake()
|
||||||
|
|
||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
||||||
<< commands.SendData(tctx.server, data)
|
<< commands.SendData(tctx.server, data)
|
||||||
|
<< commands.Hook("tls_start", tutils.Placeholder())
|
||||||
)
|
)
|
||||||
tssl_server.inc.write(data())
|
tssl_server.inc.write(data())
|
||||||
tssl_server.obj.do_handshake()
|
|
||||||
data = tutils.Placeholder()
|
|
||||||
assert (
|
|
||||||
playbook
|
|
||||||
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
|
||||||
<< commands.SendData(tctx.client, data)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert playbook.layer._handle_event == playbook.layer.state_process
|
|
||||||
assert tctx.server.tls_established
|
assert tctx.server.tls_established
|
||||||
|
|
||||||
# Server TLS is established, we can now reply to the client handshake...
|
# Server TLS is established, we can now reply to the client handshake...
|
||||||
tssl_client.inc.write(data())
|
|
||||||
with pytest.raises(ssl.SSLWantReadError):
|
|
||||||
tssl_client.obj.do_handshake()
|
|
||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.client, tssl_client.out.read())
|
>> reply_tls_start()
|
||||||
<< commands.SendData(tctx.client, data)
|
<< commands.SendData(tctx.client, data)
|
||||||
)
|
)
|
||||||
tssl_client.inc.write(data())
|
tssl_client.inc.write(data())
|
||||||
tssl_client.obj.do_handshake()
|
tssl_client.obj.do_handshake()
|
||||||
|
interact(playbook, tctx.client, tssl_client)
|
||||||
|
|
||||||
# Both handshakes completed!
|
# Both handshakes completed!
|
||||||
assert tctx.client.tls_established
|
assert tctx.client.tls_established
|
||||||
assert tctx.server.tls_established
|
assert tctx.server.tls_established
|
||||||
|
_test_echo(playbook, tssl_server, tctx.server)
|
||||||
assert tssl_client.obj.selected_alpn_protocol() == "foo"
|
_test_echo(playbook, tssl_client, tctx.client)
|
||||||
assert tssl_server.obj.selected_alpn_protocol() == "foo"
|
|
||||||
|
@ -5,13 +5,13 @@ from test.mitmproxy.proxy2 import tutils
|
|||||||
class TestNextLayer:
|
class TestNextLayer:
|
||||||
def test_simple(self, tctx):
|
def test_simple(self, tctx):
|
||||||
nl = layer.NextLayer(tctx)
|
nl = layer.NextLayer(tctx)
|
||||||
playbook = tutils.playbook(nl)
|
playbook = tutils.playbook(nl, hooks=True)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.client, b"foo")
|
>> events.DataReceived(tctx.client, b"foo")
|
||||||
<< commands.Hook("next_layer", nl)
|
<< commands.Hook("next_layer", nl)
|
||||||
>> events.HookReply(-1)
|
>> tutils.reply()
|
||||||
>> events.DataReceived(tctx.client, b"bar")
|
>> events.DataReceived(tctx.client, b"bar")
|
||||||
<< commands.Hook("next_layer", nl)
|
<< commands.Hook("next_layer", nl)
|
||||||
)
|
)
|
||||||
@ -21,7 +21,7 @@ class TestNextLayer:
|
|||||||
nl.layer = tutils.EchoLayer(tctx)
|
nl.layer = tutils.EchoLayer(tctx)
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.HookReply(-1)
|
>> tutils.reply()
|
||||||
<< commands.SendData(tctx.client, b"foo")
|
<< commands.SendData(tctx.client, b"foo")
|
||||||
<< commands.SendData(tctx.client, b"bar")
|
<< commands.SendData(tctx.client, b"bar")
|
||||||
)
|
)
|
||||||
@ -45,7 +45,7 @@ class TestNextLayer:
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.HookReply(-2)
|
>> tutils.reply(to=-2)
|
||||||
<< commands.SendData(tctx.client, b"foo")
|
<< commands.SendData(tctx.client, b"foo")
|
||||||
<< commands.SendData(tctx.client, b"bar")
|
<< commands.SendData(tctx.client, b"bar")
|
||||||
)
|
)
|
||||||
@ -63,7 +63,7 @@ class TestNextLayer:
|
|||||||
handle = nl.handle_event
|
handle = nl.handle_event
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.HookReply(-1)
|
>> tutils.reply()
|
||||||
<< commands.SendData(tctx.client, b"foo")
|
<< commands.SendData(tctx.client, b"foo")
|
||||||
)
|
)
|
||||||
sd, = handle(events.DataReceived(tctx.client, b"bar"))
|
sd, = handle(events.DataReceived(tctx.client, b"bar"))
|
||||||
|
@ -38,7 +38,7 @@ class TLayer(Layer):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tplaybook(tctx):
|
def tplaybook(tctx):
|
||||||
return tutils.playbook(TLayer(tctx), [])
|
return tutils.playbook(TLayer(tctx), expected=[])
|
||||||
|
|
||||||
|
|
||||||
def test_simple(tplaybook):
|
def test_simple(tplaybook):
|
||||||
@ -158,7 +158,7 @@ def test_command_reply(tplaybook):
|
|||||||
tplaybook
|
tplaybook
|
||||||
>> TEvent()
|
>> TEvent()
|
||||||
<< TCommand()
|
<< TCommand()
|
||||||
>> TCommandReply(-1, 42)
|
>> tutils.reply(42)
|
||||||
)
|
)
|
||||||
assert tplaybook.actual[1] == tplaybook.actual[2].command
|
assert tplaybook.actual[1] == tplaybook.actual[2].command
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import collections.abc
|
|||||||
import copy
|
import copy
|
||||||
import difflib
|
import difflib
|
||||||
import itertools
|
import itertools
|
||||||
|
import sys
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from mitmproxy.proxy2 import commands, context
|
from mitmproxy.proxy2 import commands, context
|
||||||
@ -101,7 +102,7 @@ class playbook:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer: Layer,
|
layer: Layer,
|
||||||
hooks: bool = False,
|
hooks: bool = True,
|
||||||
logs: bool = False,
|
logs: bool = False,
|
||||||
expected: typing.Optional[TPlaybook] = None,
|
expected: typing.Optional[TPlaybook] = None,
|
||||||
):
|
):
|
||||||
@ -196,13 +197,13 @@ class playbook:
|
|||||||
class reply(events.Event):
|
class reply(events.Event):
|
||||||
args: typing.Tuple[typing.Any, ...]
|
args: typing.Tuple[typing.Any, ...]
|
||||||
to: typing.Union[commands.Command, int]
|
to: typing.Union[commands.Command, int]
|
||||||
side_effect: typing.Callable[[commands.Command], typing.Any]
|
side_effect: typing.Callable[[typing.Any], typing.Any]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
to: typing.Union[commands.Command, int] = -1,
|
to: typing.Union[commands.Command, int] = -1,
|
||||||
side_effect: typing.Callable[[commands.Command], typing.Any] = lambda cmd: None
|
side_effect: typing.Callable[[typing.Any], None] = lambda x: None
|
||||||
):
|
):
|
||||||
"""Utility method to reply to the latest hook in playbooks."""
|
"""Utility method to reply to the latest hook in playbooks."""
|
||||||
self.args = args
|
self.args = args
|
||||||
@ -226,6 +227,7 @@ class reply(events.Event):
|
|||||||
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
|
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
|
||||||
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
|
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
|
||||||
|
|
||||||
|
assert isinstance(self.to, commands.Command)
|
||||||
self.side_effect(self.to)
|
self.side_effect(self.to)
|
||||||
reply_cls = command_reply_subclasses[type(self.to)]
|
reply_cls = command_reply_subclasses[type(self.to)]
|
||||||
try:
|
try:
|
||||||
@ -272,14 +274,16 @@ def Placeholder() -> typing.Any:
|
|||||||
class EchoLayer(Layer):
|
class EchoLayer(Layer):
|
||||||
"""Echo layer that sends all data back to the client in lowercase."""
|
"""Echo layer that sends all data back to the client in lowercase."""
|
||||||
|
|
||||||
def _handle_event(self, event: events.Event):
|
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
yield commands.SendData(event.connection, event.data.lower())
|
yield commands.SendData(event.connection, event.data.lower())
|
||||||
|
|
||||||
|
|
||||||
def next_layer(
|
def next_layer(
|
||||||
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]]
|
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]],
|
||||||
) -> events.HookReply:
|
*args,
|
||||||
|
**kwargs
|
||||||
|
) -> reply:
|
||||||
"""
|
"""
|
||||||
Helper function to simplify the syntax for next_layer events from this:
|
Helper function to simplify the syntax for next_layer events from this:
|
||||||
|
|
||||||
@ -294,21 +298,10 @@ def next_layer(
|
|||||||
|
|
||||||
<< commands.Hook("next_layer", next_layer)
|
<< commands.Hook("next_layer", next_layer)
|
||||||
>> tutils.next_layer(next_layer, tutils.EchoLayer)
|
>> tutils.next_layer(next_layer, tutils.EchoLayer)
|
||||||
>> tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context)
|
|
||||||
"""
|
"""
|
||||||
raise RuntimeError("Does tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) work?")
|
|
||||||
if isinstance(layer, type):
|
|
||||||
def make_layer(ctx: context.Context) -> Layer:
|
|
||||||
return layer(ctx)
|
|
||||||
else:
|
|
||||||
make_layer = layer
|
|
||||||
|
|
||||||
def set_layer(playbook: playbook) -> None:
|
def set_layer(hook: commands.Hook) -> None:
|
||||||
last_command = playbook.actual[-1]
|
assert isinstance(hook.data, NextLayer)
|
||||||
assert isinstance(last_command, commands.Hook)
|
hook.data.layer = layer(hook.data.context)
|
||||||
assert isinstance(last_command.data, NextLayer)
|
|
||||||
last_command.data.layer = make_layer(last_command.data.context)
|
|
||||||
|
|
||||||
reply = events.HookReply(-1)
|
return reply(*args, side_effect=set_layer, **kwargs)
|
||||||
reply._playbook_eval = set_layer
|
|
||||||
return reply
|
|
||||||
|
Loading…
Reference in New Issue
Block a user