mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
[sans-io] add ALPN support to TLS layer, tests++
This commit is contained in:
parent
3f7b850268
commit
b5a3343d03
@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import MutableMapping, Generator, Optional, Iterable, Iterator
|
from typing import MutableMapping, Optional, Iterator
|
||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from mitmproxy import exceptions
|
|
||||||
from mitmproxy.certs import CertStore
|
from mitmproxy.certs import CertStore
|
||||||
from mitmproxy.proxy.protocol import TlsClientHello
|
from mitmproxy.proxy.protocol import TlsClientHello
|
||||||
from mitmproxy.proxy.protocol import tls
|
from mitmproxy.proxy.protocol import tls
|
||||||
@ -153,8 +152,8 @@ class TLSLayer(layer.Layer):
|
|||||||
self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO
|
self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||||
self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO
|
self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||||
elif client.tls:
|
elif client.tls:
|
||||||
yield from self.start_client_tls()
|
|
||||||
self.state[server] = ConnectionState.NO_TLS
|
self.state[server] = ConnectionState.NO_TLS
|
||||||
|
yield from self.start_client_tls()
|
||||||
elif server.tls and server.connected:
|
elif server.tls and server.connected:
|
||||||
self.state[client] = ConnectionState.NO_TLS
|
self.state[client] = ConnectionState.NO_TLS
|
||||||
yield from self.start_server_tls()
|
yield from self.start_server_tls()
|
||||||
@ -186,17 +185,11 @@ class TLSLayer(layer.Layer):
|
|||||||
|
|
||||||
def parse_client_hello(self):
|
def parse_client_hello(self):
|
||||||
# Check if ClientHello is complete
|
# Check if ClientHello is complete
|
||||||
try:
|
client_hello = get_client_hello(self.recv_buffer[self.context.client])
|
||||||
client_hello = get_client_hello(self.recv_buffer[self.context.client])[4:]
|
if client_hello:
|
||||||
self.client_hello = TlsClientHello(client_hello)
|
self.client_hello = TlsClientHello(client_hello[4:])
|
||||||
except exceptions.TlsProtocolException:
|
|
||||||
return False
|
|
||||||
except EOFError as e:
|
|
||||||
raise exceptions.TlsProtocolException(
|
|
||||||
f'Cannot parse Client Hello: {e}, Raw Client Hello: {client_hello}'
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return True
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def process(self, event: events.Event):
|
def process(self, event: events.Event):
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
@ -235,21 +228,20 @@ class TLSLayer(layer.Layer):
|
|||||||
not self.client_hello.sni
|
not self.client_hello.sni
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# What do we do with the client connection now?
|
||||||
if client_tls_requires_server_connection and not self.context.server.connected:
|
|
||||||
yield commands.OpenConnection(self.context.server)
|
|
||||||
|
|
||||||
if not self.context.server.connected:
|
|
||||||
# We are only in the WAIT_FOR_CLIENTHELLO branch if we have two TLS conns.
|
|
||||||
assert self.context.server.tls
|
|
||||||
self.state[server] = ConnectionState.NO_TLS
|
|
||||||
else:
|
|
||||||
yield from self.start_server_tls()
|
|
||||||
if client_tls_requires_server_connection:
|
if client_tls_requires_server_connection:
|
||||||
self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS
|
self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS
|
||||||
else:
|
else:
|
||||||
yield from self.start_client_tls()
|
yield from self.start_client_tls()
|
||||||
|
|
||||||
|
# What do we do with the server connection now?
|
||||||
|
if client_tls_requires_server_connection and not self.context.server.connected:
|
||||||
|
yield commands.OpenConnection(self.context.server)
|
||||||
|
if not self.context.server.connected:
|
||||||
|
self.state[server] = ConnectionState.NO_TLS
|
||||||
|
else:
|
||||||
|
yield from self.start_server_tls()
|
||||||
|
|
||||||
def process_negotiate(self, event: events.DataReceived):
|
def process_negotiate(self, event: events.DataReceived):
|
||||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||||
if event.data:
|
if event.data:
|
||||||
@ -291,16 +283,23 @@ class TLSLayer(layer.Layer):
|
|||||||
server = self.context.server
|
server = self.context.server
|
||||||
|
|
||||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
|
|
||||||
|
if self.client_hello:
|
||||||
|
alpn = [
|
||||||
|
x for x in self.client_hello.alpn_protocols
|
||||||
|
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
|
||||||
|
]
|
||||||
|
ssl_context.set_alpn_protos(alpn)
|
||||||
|
|
||||||
self.tls[server] = SSL.Connection(ssl_context)
|
self.tls[server] = SSL.Connection(ssl_context)
|
||||||
|
|
||||||
if server.sni:
|
if server.sni:
|
||||||
if server.sni is True:
|
if server.sni is True:
|
||||||
if self.client_hello:
|
if self.client_hello and self.client_hello.sni:
|
||||||
server.sni = self.client_hello.sni.encode("idna")
|
server.sni = self.client_hello.sni.encode("idna")
|
||||||
else:
|
else:
|
||||||
server.sni = server.address[0].encode("idna")
|
server.sni = server.address[0].encode("idna")
|
||||||
self.tls[server].set_tlsext_host_name(server.sni)
|
self.tls[server].set_tlsext_host_name(server.sni)
|
||||||
# FIXME: Handle ALPN
|
|
||||||
self.tls[server].set_connect_state()
|
self.tls[server].set_connect_state()
|
||||||
|
|
||||||
self.state[server] = ConnectionState.NEGOTIATING
|
self.state[server] = ConnectionState.NEGOTIATING
|
||||||
@ -312,6 +311,7 @@ class TLSLayer(layer.Layer):
|
|||||||
def start_client_tls(self):
|
def start_client_tls(self):
|
||||||
# FIXME
|
# FIXME
|
||||||
client = self.context.client
|
client = self.context.client
|
||||||
|
server = self.context.server
|
||||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
cert, privkey, cert_chain = CertStore.from_store(
|
cert, privkey, cert_chain = CertStore.from_store(
|
||||||
os.path.expanduser("~/.mitmproxy"), "mitmproxy"
|
os.path.expanduser("~/.mitmproxy"), "mitmproxy"
|
||||||
@ -319,6 +319,16 @@ class TLSLayer(layer.Layer):
|
|||||||
context.use_privatekey(privkey)
|
context.use_privatekey(privkey)
|
||||||
context.use_certificate(cert.x509)
|
context.use_certificate(cert.x509)
|
||||||
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
|
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
|
||||||
|
|
||||||
|
if self.state[server] == ConnectionState.ESTABLISHED:
|
||||||
|
alpn_for_client = self.tls[server].get_alpn_proto_negotiated()
|
||||||
|
|
||||||
|
def alpn_select_callback(conn_, options):
|
||||||
|
if alpn_for_client in options:
|
||||||
|
return alpn_for_client
|
||||||
|
|
||||||
|
context.set_alpn_select_callback(alpn_select_callback)
|
||||||
|
|
||||||
self.tls[client] = SSL.Connection(context)
|
self.tls[client] = SSL.Connection(context)
|
||||||
self.tls[client].set_accept_state()
|
self.tls[client].set_accept_state()
|
||||||
|
|
||||||
|
@ -64,14 +64,21 @@ def test_get_client_hello():
|
|||||||
class SSLTest:
|
class SSLTest:
|
||||||
"""Helper container for Python's builtin SSL object."""
|
"""Helper container for Python's builtin SSL object."""
|
||||||
|
|
||||||
def __init__(self, server_side=False):
|
def __init__(self, server_side=False, alpn=None):
|
||||||
self.inc = ssl.MemoryBIO()
|
self.inc = ssl.MemoryBIO()
|
||||||
self.out = ssl.MemoryBIO()
|
self.out = ssl.MemoryBIO()
|
||||||
self.ctx = ssl.SSLContext()
|
self.ctx = ssl.SSLContext()
|
||||||
|
if alpn:
|
||||||
|
self.ctx.set_alpn_protocols(alpn)
|
||||||
if server_side:
|
if server_side:
|
||||||
# FIXME: Replace hardcoded location
|
# FIXME: Replace hardcoded location
|
||||||
self.ctx.load_cert_chain(os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem"))
|
self.ctx.load_cert_chain(os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem"))
|
||||||
self.obj = self.ctx.wrap_bio(self.inc, self.out, server_side=server_side)
|
self.obj = self.ctx.wrap_bio(
|
||||||
|
self.inc,
|
||||||
|
self.out,
|
||||||
|
server_hostname=None if server_side else "example.com",
|
||||||
|
server_side=server_side,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_no_tls(tctx: context.Context):
|
def test_no_tls(tctx: context.Context):
|
||||||
@ -211,3 +218,117 @@ def test_server_tls(tctx):
|
|||||||
|
|
||||||
# Echo
|
# Echo
|
||||||
echo(playbook, tssl, tctx.server)
|
echo(playbook, tssl, tctx.server)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_tls_client_server(tctx, alpn):
|
||||||
|
layer = tls.TLSLayer(tctx)
|
||||||
|
playbook = tutils.playbook(layer)
|
||||||
|
tctx.client.tls = True
|
||||||
|
tctx.server.tls = True
|
||||||
|
tssl_client = SSLTest(alpn=alpn)
|
||||||
|
|
||||||
|
# Handshake
|
||||||
|
assert playbook
|
||||||
|
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||||
|
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||||
|
|
||||||
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
|
tssl_client.obj.do_handshake()
|
||||||
|
client_hello = tssl_client.out.read()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.DataReceived(tctx.client, client_hello[:42])
|
||||||
|
<< None
|
||||||
|
)
|
||||||
|
# Still waiting...
|
||||||
|
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||||
|
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
|
||||||
|
|
||||||
|
# Finish sending ClientHello
|
||||||
|
playbook >> events.DataReceived(tctx.client, client_hello[42:])
|
||||||
|
return playbook, tssl_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_tls_client_server_no_server_conn(tctx):
|
||||||
|
"""
|
||||||
|
Here we test the scenario where a server connection is _not_ required
|
||||||
|
to establish TLS with the client. After determining this when parsing the ClientHello,
|
||||||
|
we only establish a connection with the client. The server connection may ultimately
|
||||||
|
be established when OpenConnection is called.
|
||||||
|
"""
|
||||||
|
playbook, _ = _test_tls_client_server(tctx, None)
|
||||||
|
data = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< commands.SendData(tctx.client, data)
|
||||||
|
)
|
||||||
|
assert data()
|
||||||
|
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
|
||||||
|
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NO_TLS
|
||||||
|
|
||||||
|
|
||||||
|
def test_tls_client_server_alpn(tctx):
|
||||||
|
"""
|
||||||
|
Here we test the scenario where a server connection is required (e.g. because of ALPN negotation)
|
||||||
|
to establish TLS with the client.
|
||||||
|
"""
|
||||||
|
tssl_server = SSLTest(server_side=True, alpn=["foo", "bar"])
|
||||||
|
|
||||||
|
playbook, tssl_client = _test_tls_client_server(tctx, ["qux", "foo"])
|
||||||
|
|
||||||
|
# We should now get instructed to open a server connection.
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< commands.OpenConnection(tctx.server)
|
||||||
|
)
|
||||||
|
tctx.server.connected = True
|
||||||
|
data = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.OpenConnectionReply(-1, None)
|
||||||
|
<< commands.SendData(tctx.server, data)
|
||||||
|
)
|
||||||
|
assert playbook.layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_SERVER_TLS
|
||||||
|
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING
|
||||||
|
|
||||||
|
# Establish TLS with the server...
|
||||||
|
tssl_server.inc.write(data())
|
||||||
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
|
tssl_server.obj.do_handshake()
|
||||||
|
data = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
||||||
|
<< commands.SendData(tctx.server, data)
|
||||||
|
)
|
||||||
|
tssl_server.inc.write(data())
|
||||||
|
tssl_server.obj.do_handshake()
|
||||||
|
data = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.DataReceived(tctx.server, tssl_server.out.read())
|
||||||
|
<< commands.SendData(tctx.client, data)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
|
||||||
|
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
|
||||||
|
|
||||||
|
# Server TLS is established, we can now reply to the client handshake...
|
||||||
|
tssl_client.inc.write(data())
|
||||||
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
|
tssl_client.obj.do_handshake()
|
||||||
|
data = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.DataReceived(tctx.client, tssl_client.out.read())
|
||||||
|
<< commands.SendData(tctx.client, data)
|
||||||
|
)
|
||||||
|
tssl_client.inc.write(data())
|
||||||
|
tssl_client.obj.do_handshake()
|
||||||
|
|
||||||
|
# Both handshakes completed!
|
||||||
|
assert playbook.layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED
|
||||||
|
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
|
||||||
|
|
||||||
|
assert tssl_client.obj.selected_alpn_protocol() == "foo"
|
||||||
|
assert tssl_server.obj.selected_alpn_protocol() == "foo"
|
||||||
|
Loading…
Reference in New Issue
Block a user