mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
[sans-io] split tls layer into client and server layers
this drastically reduces the complexity of the TLS code and makes it easier to implement the remaining bits.
This commit is contained in:
parent
34274744a1
commit
8f3db90def
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, List, Union
|
from typing import Optional, List, Union, Sequence
|
||||||
|
|
||||||
from mitmproxy.options import Options
|
from mitmproxy.options import Options
|
||||||
|
|
||||||
@ -10,7 +10,9 @@ class Connection:
|
|||||||
address: tuple
|
address: tuple
|
||||||
connected: bool = False
|
connected: bool = False
|
||||||
tls: bool = False
|
tls: bool = False
|
||||||
|
tls_established: bool = False
|
||||||
alpn: Optional[bytes] = None
|
alpn: Optional[bytes] = None
|
||||||
|
alpn_offers: Sequence[bytes] = ()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{type(self).__name__}({repr(self.__dict__)})"
|
return f"{type(self).__name__}({repr(self.__dict__)})"
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
from . import modes
|
from . import modes
|
||||||
from .http import HTTPLayer
|
from .http import HTTPLayer
|
||||||
from .tcp import TCPLayer
|
from .tcp import TCPLayer
|
||||||
from .tls import TLSLayer
|
from .tls import ClientTLSLayer, ServerTLSLayer
|
||||||
from .websocket import WebsocketLayer
|
from .websocket import WebsocketLayer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"modes",
|
"modes",
|
||||||
"HTTPLayer",
|
"HTTPLayer",
|
||||||
"TCPLayer",
|
"TCPLayer",
|
||||||
"TLSLayer",
|
"ClientTLSLayer", "ServerTLSLayer",
|
||||||
"WebsocketLayer"
|
"WebsocketLayer"
|
||||||
]
|
]
|
||||||
|
@ -6,9 +6,12 @@ from mitmproxy.proxy2.context import Context, Server
|
|||||||
class ReverseProxy(layer.Layer):
|
class ReverseProxy(layer.Layer):
|
||||||
def __init__(self, context: Context):
|
def __init__(self, context: Context):
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
server_addr = server_spec.parse_with_mode(context.options.mode)[1].address
|
spec = server_spec.parse_with_mode(context.options.mode)[1]
|
||||||
self.context.server = Server(server_addr)
|
self.context.server = Server(spec.address)
|
||||||
|
if spec.scheme != "http":
|
||||||
|
self.context.server.tls = True
|
||||||
|
if not context.options.keep_host_header:
|
||||||
|
self.context.server.sni = spec.address[0]
|
||||||
child_layer = layer.NextLayer(self.context)
|
child_layer = layer.NextLayer(self.context)
|
||||||
self._handle_event = child_layer.handle_event
|
self._handle_event = child_layer.handle_event
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from enum import Enum
|
from typing import MutableMapping, Optional, Iterator, Union, Generator, Any
|
||||||
from typing import MutableMapping, Optional, Iterator
|
|
||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
@ -13,14 +12,6 @@ from mitmproxy.proxy2 import layer, commands, events
|
|||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
|
|
||||||
|
|
||||||
class ConnectionState(Enum):
|
|
||||||
NO_TLS = 1
|
|
||||||
WAIT_FOR_CLIENTHELLO = 2
|
|
||||||
WAIT_FOR_SERVER_TLS = 3
|
|
||||||
NEGOTIATING = 5
|
|
||||||
ESTABLISHED = 6
|
|
||||||
|
|
||||||
|
|
||||||
def is_tls_handshake_record(d: bytes) -> bool:
|
def is_tls_handshake_record(d: bytes) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
@ -78,265 +69,34 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class TLSLayer(layer.Layer):
|
def parse_client_hello(data: bytes) -> Optional[TlsClientHello]:
|
||||||
"""
|
"""
|
||||||
The TLS layer manages both client- and server-side TLS connection state.
|
Check if the supplied bytes contain a full ClientHello message,
|
||||||
This unfortunately is quite complex as the client handshake may depend on the server
|
and if so, parse it.
|
||||||
handshake and vice versa: We need the client's SNI and ALPN to connect upstream,
|
|
||||||
and we need the server's ALPN choice to complete our client TLS handshake.
|
|
||||||
On top, we may have configurations where TLS is only added on one end,
|
|
||||||
and we also may have OpenConnection events which change the server's TLS configuration.
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- A ClientHello object on success
|
||||||
|
- None, if the TLS record is not complete
|
||||||
|
|
||||||
The following state machine shows the possible states for client and server connection:
|
Raises:
|
||||||
|
- A ValueError, if the passed ClientHello is invalid
|
||||||
Legend:
|
|
||||||
/: NO_TLS
|
|
||||||
WCH: WAIT_FOR_CLIENTHELLO
|
|
||||||
WST: WAIT_FOR_SERVER_TLS
|
|
||||||
N: NEGOTIATING
|
|
||||||
E: ESTABLISHED
|
|
||||||
|
|
||||||
+------------+ +---+ +------------+
|
|
||||||
|Client State|--------> | / | |Server State|
|
|
||||||
+------------+ no tls +---+ server tls, +------------+ server tls,
|
|
||||||
| client tls | | | no client tls
|
|
||||||
v client tls +-----------------+ | +--------------------+
|
|
||||||
| | |
|
|
||||||
+------------------------------+ | | no server tls |
|
|
||||||
| no server tls | v v v
|
|
||||||
| v OpenConn(TLS)
|
|
||||||
v server tls +---+ not needed +--------------------> +---+
|
|
||||||
+---> +---+ |WCH+-------------> | / | | N | <-+
|
|
||||||
+---+ | | N | +---+ +---+ <--------------------+ |
|
|
||||||
|WCH| | +> +---+ | OpenConn(No TLS) | |
|
|
||||||
+---+ | | | | ^ | |
|
|
||||||
| | | v | already connec- | handshake done v |
|
|
||||||
|ClientHello arrives | | +---+ | ted or server | |
|
|
||||||
| | | | E | | info needed | OpenConn(No TLS) +---+ |
|
|
||||||
+----------------------+ | +---+ | +--------------------+ E | |
|
|
||||||
| no server info needed | | +---+ |
|
|
||||||
| | | |
|
|
||||||
v server info needed | +------------------------------------------------+
|
|
||||||
|
|
|
||||||
+---+ |
|
|
||||||
|WST|-----------------------+
|
|
||||||
+---+ server tls established
|
|
||||||
(or errored)
|
|
||||||
"""
|
"""
|
||||||
tls: MutableMapping[context.Connection, SSL.Connection]
|
|
||||||
state: MutableMapping[context.Connection, ConnectionState]
|
|
||||||
recv_buffer: MutableMapping[context.Connection, bytearray]
|
|
||||||
client_hello: Optional[TlsClientHello]
|
|
||||||
|
|
||||||
child_layer: layer.Layer
|
|
||||||
|
|
||||||
def __init__(self, context: context.Context):
|
|
||||||
super().__init__(context)
|
|
||||||
self.tls = {}
|
|
||||||
self.state = {}
|
|
||||||
self.recv_buffer = {
|
|
||||||
context.client: bytearray(),
|
|
||||||
context.server: bytearray()
|
|
||||||
}
|
|
||||||
self.client_hello = None
|
|
||||||
|
|
||||||
self.child_layer = layer.NextLayer(context)
|
|
||||||
|
|
||||||
@expect(events.Start)
|
|
||||||
def start(self, _) -> commands.TCommandGenerator:
|
|
||||||
client = self.context.client
|
|
||||||
server = self.context.server
|
|
||||||
|
|
||||||
if client.tls and server.tls:
|
|
||||||
self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO
|
|
||||||
self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO
|
|
||||||
elif client.tls:
|
|
||||||
self.state[server] = ConnectionState.NO_TLS
|
|
||||||
yield from self.start_client_tls()
|
|
||||||
elif server.tls and server.connected:
|
|
||||||
self.state[client] = ConnectionState.NO_TLS
|
|
||||||
yield from self.start_server_tls()
|
|
||||||
else:
|
|
||||||
self.state[client] = ConnectionState.NO_TLS
|
|
||||||
self.state[server] = ConnectionState.NO_TLS
|
|
||||||
|
|
||||||
yield from self.child_layer.handle_event(events.Start())
|
|
||||||
self._handle_event = self.process
|
|
||||||
|
|
||||||
_handle_event = start
|
|
||||||
|
|
||||||
def send(self, send_command: commands.SendData) -> commands.TCommandGenerator:
|
|
||||||
if self.state[send_command.connection] == ConnectionState.NO_TLS:
|
|
||||||
yield send_command
|
|
||||||
else:
|
|
||||||
yield commands.Log(f"Plain{send_command}")
|
|
||||||
self.tls[send_command.connection].sendall(send_command.data)
|
|
||||||
yield from self.tls_interact(send_command.connection)
|
|
||||||
|
|
||||||
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
|
||||||
for command in self.child_layer.handle_event(event):
|
|
||||||
if isinstance(command, commands.SendData):
|
|
||||||
yield from self.send(command)
|
|
||||||
elif isinstance(command, commands.OpenConnection):
|
|
||||||
raise NotImplementedError("Cannot open connection")
|
|
||||||
else:
|
|
||||||
yield command
|
|
||||||
|
|
||||||
def parse_client_hello(self):
|
|
||||||
# Check if ClientHello is complete
|
# Check if ClientHello is complete
|
||||||
client_hello = get_client_hello(self.recv_buffer[self.context.client])
|
client_hello = get_client_hello(data)
|
||||||
if client_hello:
|
if client_hello:
|
||||||
self.client_hello = TlsClientHello(client_hello[4:])
|
return TlsClientHello(client_hello[4:])
|
||||||
return True
|
return None
|
||||||
return False
|
|
||||||
|
|
||||||
def process(self, event: events.Event):
|
|
||||||
if isinstance(event, events.DataReceived):
|
|
||||||
state = self.state[event.connection]
|
|
||||||
|
|
||||||
if state is ConnectionState.WAIT_FOR_CLIENTHELLO:
|
class _TLSLayer(layer.Layer):
|
||||||
yield from self.process_wait_for_clienthello(event)
|
send_buffer: MutableMapping[SSL.Connection, bytearray]
|
||||||
elif state is ConnectionState.WAIT_FOR_SERVER_TLS:
|
tls: MutableMapping[context.Connection, SSL.Connection]
|
||||||
self.recv_buffer[self.context.client].extend(event.data)
|
child_layer: Optional[layer.Layer] = None
|
||||||
elif state is ConnectionState.NEGOTIATING:
|
|
||||||
yield from self.process_negotiate(event)
|
|
||||||
elif state is ConnectionState.NO_TLS:
|
|
||||||
yield from self.event_to_child(event)
|
|
||||||
elif state is ConnectionState.ESTABLISHED:
|
|
||||||
yield from self.process_relay(event)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Unexpected state")
|
|
||||||
else:
|
|
||||||
yield from self.event_to_child(event)
|
|
||||||
|
|
||||||
def process_wait_for_clienthello(self, event: events.DataReceived):
|
def __init__(self, context):
|
||||||
client = self.context.client
|
super().__init__(context)
|
||||||
server = self.context.server
|
self.send_buffer = {}
|
||||||
# We are not ready to process this yet.
|
self.tls = {}
|
||||||
self.recv_buffer[event.connection].extend(event.data)
|
|
||||||
|
|
||||||
if event.connection == client and self.parse_client_hello():
|
|
||||||
yield commands.Log(f"Client Hello: {self.client_hello}")
|
|
||||||
|
|
||||||
client_tls_requires_server_connection = (
|
|
||||||
self.context.server.tls and
|
|
||||||
self.context.options.upstream_cert and
|
|
||||||
(
|
|
||||||
self.context.options.add_upstream_certs_to_client_chain or
|
|
||||||
self.client_hello.alpn_protocols or
|
|
||||||
not self.client_hello.sni
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# What do we do with the client connection now?
|
|
||||||
if client_tls_requires_server_connection:
|
|
||||||
self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS
|
|
||||||
else:
|
|
||||||
yield from self.start_client_tls()
|
|
||||||
|
|
||||||
# What do we do with the server connection now?
|
|
||||||
if client_tls_requires_server_connection and not self.context.server.connected:
|
|
||||||
yield commands.OpenConnection(self.context.server)
|
|
||||||
if not self.context.server.connected:
|
|
||||||
self.state[server] = ConnectionState.NO_TLS
|
|
||||||
else:
|
|
||||||
yield from self.start_server_tls()
|
|
||||||
|
|
||||||
def process_negotiate(self, event: events.DataReceived):
|
|
||||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
|
||||||
if event.data:
|
|
||||||
self.tls[event.connection].bio_write(event.data)
|
|
||||||
try:
|
|
||||||
self.tls[event.connection].do_handshake()
|
|
||||||
except SSL.WantReadError:
|
|
||||||
yield from self.tls_interact(event.connection)
|
|
||||||
else:
|
|
||||||
self.state[event.connection] = ConnectionState.ESTABLISHED
|
|
||||||
event.connection.sni = self.tls[event.connection].get_servername()
|
|
||||||
event.connection.alpn = self.tls[event.connection].get_alpn_proto_negotiated()
|
|
||||||
|
|
||||||
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
|
|
||||||
yield from self.process(events.DataReceived(event.connection, b""))
|
|
||||||
|
|
||||||
if self.state[self.context.client] == ConnectionState.WAIT_FOR_SERVER_TLS:
|
|
||||||
assert event.connection == self.context.server
|
|
||||||
yield from self.start_client_tls()
|
|
||||||
|
|
||||||
def process_relay(self, event: events.DataReceived):
|
|
||||||
if event.data:
|
|
||||||
self.tls[event.connection].bio_write(event.data)
|
|
||||||
yield from self.tls_interact(event.connection)
|
|
||||||
|
|
||||||
plaintext = bytearray()
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
plaintext.extend(self.tls[event.connection].recv(65535))
|
|
||||||
except (SSL.WantReadError, SSL.ZeroReturnError):
|
|
||||||
break
|
|
||||||
|
|
||||||
if plaintext:
|
|
||||||
evt = events.DataReceived(event.connection, bytes(plaintext))
|
|
||||||
yield commands.Log(f"Plain{evt}")
|
|
||||||
yield from self.event_to_child(evt)
|
|
||||||
|
|
||||||
def start_server_tls(self):
|
|
||||||
server = self.context.server
|
|
||||||
|
|
||||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
|
||||||
|
|
||||||
if self.client_hello:
|
|
||||||
alpn = [
|
|
||||||
x for x in self.client_hello.alpn_protocols
|
|
||||||
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
|
|
||||||
]
|
|
||||||
ssl_context.set_alpn_protos(alpn)
|
|
||||||
|
|
||||||
self.tls[server] = SSL.Connection(ssl_context)
|
|
||||||
|
|
||||||
if server.sni:
|
|
||||||
if server.sni is True:
|
|
||||||
if self.client_hello and self.client_hello.sni:
|
|
||||||
server.sni = self.client_hello.sni.encode("idna")
|
|
||||||
else:
|
|
||||||
server.sni = server.address[0].encode("idna")
|
|
||||||
self.tls[server].set_tlsext_host_name(server.sni)
|
|
||||||
self.tls[server].set_connect_state()
|
|
||||||
|
|
||||||
self.state[server] = ConnectionState.NEGOTIATING
|
|
||||||
yield from self.process(events.DataReceived(
|
|
||||||
server, bytes(self.recv_buffer[server])
|
|
||||||
))
|
|
||||||
self.recv_buffer[server] = bytearray()
|
|
||||||
|
|
||||||
def start_client_tls(self):
|
|
||||||
# FIXME
|
|
||||||
client = self.context.client
|
|
||||||
server = self.context.server
|
|
||||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
|
||||||
cert, privkey, cert_chain = CertStore.from_store(
|
|
||||||
os.path.expanduser("~/.mitmproxy"), "mitmproxy"
|
|
||||||
).get_cert(b"example.com", (b"example.com",))
|
|
||||||
context.use_privatekey(privkey)
|
|
||||||
context.use_certificate(cert.x509)
|
|
||||||
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
|
|
||||||
|
|
||||||
if self.state[server] == ConnectionState.ESTABLISHED:
|
|
||||||
alpn_for_client = self.tls[server].get_alpn_proto_negotiated()
|
|
||||||
|
|
||||||
def alpn_select_callback(conn_, options):
|
|
||||||
if alpn_for_client in options:
|
|
||||||
return alpn_for_client
|
|
||||||
|
|
||||||
context.set_alpn_select_callback(alpn_select_callback)
|
|
||||||
|
|
||||||
self.tls[client] = SSL.Connection(context)
|
|
||||||
self.tls[client].set_accept_state()
|
|
||||||
|
|
||||||
self.state[client] = ConnectionState.NEGOTIATING
|
|
||||||
yield from self.process(events.DataReceived(
|
|
||||||
client, bytes(self.recv_buffer[client])
|
|
||||||
))
|
|
||||||
self.recv_buffer[client] = bytearray()
|
|
||||||
|
|
||||||
def tls_interact(self, conn: context.Connection):
|
def tls_interact(self, conn: context.Connection):
|
||||||
while True:
|
while True:
|
||||||
@ -347,3 +107,269 @@ class TLSLayer(layer.Layer):
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
yield commands.SendData(conn, data)
|
yield commands.SendData(conn, data)
|
||||||
|
|
||||||
|
def send(
|
||||||
|
self,
|
||||||
|
send_command: commands.SendData,
|
||||||
|
) -> commands.TCommandGenerator:
|
||||||
|
tls_conn = self.tls[send_command.connection]
|
||||||
|
if send_command.connection.tls_established:
|
||||||
|
tls_conn.sendall(send_command.data)
|
||||||
|
yield from self.tls_interact(send_command.connection)
|
||||||
|
else:
|
||||||
|
buf = self.send_buffer.setdefault(tls_conn, bytearray())
|
||||||
|
buf.extend(send_command.data)
|
||||||
|
|
||||||
|
def negotiate(self, event: events.DataReceived) -> Generator[commands.Command, Any, bool]:
|
||||||
|
"""
|
||||||
|
Make sure to trigger processing if done!
|
||||||
|
"""
|
||||||
|
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||||
|
tls_conn = self.tls[event.connection]
|
||||||
|
if event.data:
|
||||||
|
tls_conn.bio_write(event.data)
|
||||||
|
try:
|
||||||
|
tls_conn.do_handshake()
|
||||||
|
except SSL.WantReadError:
|
||||||
|
yield from self.tls_interact(event.connection)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
event.connection.tls_established = True
|
||||||
|
event.connection.alpn = tls_conn.get_alpn_proto_negotiated()
|
||||||
|
print(f"TLS established: {event.connection}")
|
||||||
|
# TODO: Set all other connection attributes here
|
||||||
|
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
|
||||||
|
yield from self.relay(events.DataReceived(event.connection, b""))
|
||||||
|
if tls_conn in self.send_buffer:
|
||||||
|
data_to_send = bytes(self.send_buffer.pop(tls_conn))
|
||||||
|
yield from self.send(commands.SendData(event.connection, data_to_send))
|
||||||
|
return True
|
||||||
|
|
||||||
|
def relay(self, event: events.DataReceived):
|
||||||
|
tls_conn = self.tls[event.connection]
|
||||||
|
if event.data:
|
||||||
|
tls_conn.bio_write(event.data)
|
||||||
|
yield from self.tls_interact(event.connection)
|
||||||
|
|
||||||
|
plaintext = bytearray()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
plaintext.extend(tls_conn.recv(65535))
|
||||||
|
except (SSL.WantReadError, SSL.ZeroReturnError):
|
||||||
|
break
|
||||||
|
|
||||||
|
if plaintext:
|
||||||
|
evt = events.DataReceived(event.connection, bytes(plaintext))
|
||||||
|
# yield commands.Log(f"Plain{evt}")
|
||||||
|
yield from self.event_to_child(evt)
|
||||||
|
|
||||||
|
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
|
for command in self.child_layer.handle_event(event):
|
||||||
|
if isinstance(command, commands.SendData) and command.connection in self.tls:
|
||||||
|
yield from self.send(command)
|
||||||
|
else:
|
||||||
|
yield command
|
||||||
|
|
||||||
|
|
||||||
|
class ServerTLSLayer(_TLSLayer):
|
||||||
|
"""
|
||||||
|
This layer manages TLS on potentially multiple server connections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, context: context.Context):
|
||||||
|
super().__init__(context)
|
||||||
|
self.child_layer = layer.NextLayer(context)
|
||||||
|
|
||||||
|
@expect(events.Start)
|
||||||
|
def start(self, event: events.Start) -> commands.TCommandGenerator:
|
||||||
|
yield from self.child_layer.handle_event(event)
|
||||||
|
|
||||||
|
server = self.context.server
|
||||||
|
if server.connected and server.tls:
|
||||||
|
yield from self._start_tls(server)
|
||||||
|
self._handle_event = self.process
|
||||||
|
|
||||||
|
_handle_event = start
|
||||||
|
|
||||||
|
def process(self, event: Union[events.DataReceived, events.ConnectionClosed]):
|
||||||
|
if isinstance(event, events.DataReceived) and event.connection in self.tls:
|
||||||
|
if not event.connection.tls_established:
|
||||||
|
yield from self.negotiate(event)
|
||||||
|
else:
|
||||||
|
yield from self.relay(event)
|
||||||
|
elif isinstance(event, events.OpenConnectionReply):
|
||||||
|
err = event.reply
|
||||||
|
conn = event.command.connection
|
||||||
|
if not err and conn.tls:
|
||||||
|
yield from self._start_tls(conn)
|
||||||
|
yield from self.event_to_child(event)
|
||||||
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
|
yield from self.event_to_child(event)
|
||||||
|
self.send_buffer.pop(
|
||||||
|
self.tls.pop(event.connection, None),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
|
def _start_tls(self, server: context.Server):
|
||||||
|
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
|
|
||||||
|
if server.alpn_offers:
|
||||||
|
ssl_context.set_alpn_protos(server.alpn_offers)
|
||||||
|
|
||||||
|
self.tls[server] = SSL.Connection(ssl_context)
|
||||||
|
|
||||||
|
if server.sni:
|
||||||
|
if server.sni is True:
|
||||||
|
if self.context.client.sni:
|
||||||
|
server.sni = self.context.client.sni.encode("idna")
|
||||||
|
else:
|
||||||
|
server.sni = server.address[0].encode("idna")
|
||||||
|
self.tls[server].set_tlsext_host_name(server.sni)
|
||||||
|
self.tls[server].set_connect_state()
|
||||||
|
|
||||||
|
yield from self.process(events.DataReceived(server, b""))
|
||||||
|
|
||||||
|
|
||||||
|
class ClientTLSLayer(_TLSLayer):
|
||||||
|
"""
|
||||||
|
This layer establishes TLS on a single client connection.
|
||||||
|
|
||||||
|
┌─────┐
|
||||||
|
│Start│
|
||||||
|
└┬────┘
|
||||||
|
↓
|
||||||
|
┌────────────────────┐
|
||||||
|
│Wait for ClientHello│
|
||||||
|
└┬───────────────────┘
|
||||||
|
│ Do we need server TLS info
|
||||||
|
│ to establish TLS with client?
|
||||||
|
│ ┌───────────────────┐
|
||||||
|
├─────→│Wait for Server TLS│
|
||||||
|
│ yes └┬──────────────────┘
|
||||||
|
│no │
|
||||||
|
↓ ↓
|
||||||
|
┌────────────────┐
|
||||||
|
│Process messages│
|
||||||
|
└────────────────┘
|
||||||
|
|
||||||
|
"""
|
||||||
|
recv_buffer: bytearray
|
||||||
|
|
||||||
|
def __init__(self, context: context.Context):
|
||||||
|
super().__init__(context)
|
||||||
|
self.recv_buffer = bytearray()
|
||||||
|
self.child_layer = ServerTLSLayer(self.context)
|
||||||
|
|
||||||
|
@expect(events.Start)
|
||||||
|
def state_start(self, _) -> commands.TCommandGenerator:
|
||||||
|
self.context.client.tls = True
|
||||||
|
self._handle_event = self.state_wait_for_clienthello
|
||||||
|
yield from ()
|
||||||
|
|
||||||
|
_handle_event = state_start
|
||||||
|
|
||||||
|
@expect(events.DataReceived, events.ConnectionClosed)
|
||||||
|
def state_wait_for_clienthello(self, event: events.Event):
|
||||||
|
client = self.context.client
|
||||||
|
server = self.context.server
|
||||||
|
if isinstance(event, events.DataReceived) and event.connection == client:
|
||||||
|
self.recv_buffer.extend(event.data)
|
||||||
|
try:
|
||||||
|
client_hello = parse_client_hello(self.recv_buffer)
|
||||||
|
except ValueError as e:
|
||||||
|
raise NotImplementedError() from e # TODO
|
||||||
|
|
||||||
|
if client_hello:
|
||||||
|
yield commands.Log(f"Client Hello: {client_hello}")
|
||||||
|
|
||||||
|
client.sni = client_hello.sni
|
||||||
|
client.alpn_offers = client_hello.alpn_protocols
|
||||||
|
|
||||||
|
client_tls_requires_server_connection = (
|
||||||
|
self.context.server.tls and
|
||||||
|
self.context.options.upstream_cert and
|
||||||
|
(
|
||||||
|
self.context.options.add_upstream_certs_to_client_chain or
|
||||||
|
client.alpn_offers or
|
||||||
|
not client.sni
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# What do we do with the client connection now?
|
||||||
|
if client_tls_requires_server_connection and not server.tls_established:
|
||||||
|
yield from self.start_server_tls()
|
||||||
|
self._handle_event = self.state_wait_for_server_tls
|
||||||
|
else:
|
||||||
|
yield from self.start_negotiate()
|
||||||
|
self._handle_event = self.state_process
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(event) # TODO
|
||||||
|
|
||||||
|
def state_wait_for_server_tls(self, event: events.Event):
|
||||||
|
yield from self.event_to_child(event)
|
||||||
|
# TODO: Handle case where TLS establishment fails.
|
||||||
|
# We still need a good way to signal this - one possibility would be by closing
|
||||||
|
# the connection?
|
||||||
|
if self.context.server.tls_established:
|
||||||
|
yield from self.start_negotiate()
|
||||||
|
self._handle_event = self.state_process
|
||||||
|
|
||||||
|
def state_process(self, event: events.Event):
|
||||||
|
if isinstance(event, events.DataReceived) and event.connection == self.context.client:
|
||||||
|
if not event.connection.tls_established:
|
||||||
|
yield from self.negotiate(event)
|
||||||
|
else:
|
||||||
|
yield from self.relay(event)
|
||||||
|
else:
|
||||||
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
|
def start_server_tls(self):
|
||||||
|
"""
|
||||||
|
We often need information from the upstream connection to establish TLS with the client.
|
||||||
|
For example, we need to check if the client does ALPN or not.
|
||||||
|
"""
|
||||||
|
if not self.context.server.connected:
|
||||||
|
err = yield commands.OpenConnection(self.context.server)
|
||||||
|
if err:
|
||||||
|
yield commands.Log(
|
||||||
|
"Cannot establish server connection, which is required to establish TLS with the client."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.context.server.alpn_offers = [
|
||||||
|
x for x in self.context.client.alpn_offers
|
||||||
|
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
|
||||||
|
]
|
||||||
|
|
||||||
|
yield from self.child_layer.handle_event(events.Start())
|
||||||
|
|
||||||
|
def start_negotiate(self):
|
||||||
|
if not self.child_layer:
|
||||||
|
yield from self.child_layer.handle_event(events.Start())
|
||||||
|
|
||||||
|
# FIXME: Do this properly
|
||||||
|
client = self.context.client
|
||||||
|
server = self.context.server
|
||||||
|
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
|
cert, privkey, cert_chain = CertStore.from_store(
|
||||||
|
os.path.expanduser("~/.mitmproxy"), "mitmproxy"
|
||||||
|
).get_cert(b"example.com", (b"example.com",))
|
||||||
|
context.use_privatekey(privkey)
|
||||||
|
context.use_certificate(cert.x509)
|
||||||
|
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
|
||||||
|
|
||||||
|
if server.alpn:
|
||||||
|
def alpn_select_callback(conn_, options):
|
||||||
|
if server.alpn in options:
|
||||||
|
return server.alpn
|
||||||
|
|
||||||
|
context.set_alpn_select_callback(alpn_select_callback)
|
||||||
|
|
||||||
|
self.tls[self.context.client] = SSL.Connection(context)
|
||||||
|
self.tls[self.context.client].set_accept_state()
|
||||||
|
|
||||||
|
yield from self.state_process(events.DataReceived(
|
||||||
|
client, bytes(self.recv_buffer)
|
||||||
|
))
|
||||||
|
self.recv_buffer = bytearray()
|
||||||
|
@ -59,7 +59,12 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
# self._debug("transports closed!")
|
# self._debug("transports closed!")
|
||||||
|
|
||||||
async def close_connection(self, connection):
|
async def close_connection(self, connection):
|
||||||
io = self.transports.pop(connection, None)
|
try:
|
||||||
|
io = self.transports.pop(connection)
|
||||||
|
except KeyError:
|
||||||
|
self.log(f"already closed: {connection}", "warn")
|
||||||
|
return
|
||||||
|
else:
|
||||||
self.log(f"closing {connection}", "debug")
|
self.log(f"closing {connection}", "debug")
|
||||||
try:
|
try:
|
||||||
await io.w.drain()
|
await io.w.drain()
|
||||||
@ -109,10 +114,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
print(message)
|
print(message)
|
||||||
|
|
||||||
def server_event(self, event: events.Event) -> None:
|
def server_event(self, event: events.Event) -> None:
|
||||||
self.log(f">> {event}", "debug")
|
|
||||||
layer_commands = self.layer.handle_event(event)
|
layer_commands = self.layer.handle_event(event)
|
||||||
for command in layer_commands:
|
for command in layer_commands:
|
||||||
self.log(f"<< {command}", "debug")
|
|
||||||
if isinstance(command, commands.OpenConnection):
|
if isinstance(command, commands.OpenConnection):
|
||||||
asyncio.ensure_future(
|
asyncio.ensure_future(
|
||||||
self.open_connection(command)
|
self.open_connection(command)
|
||||||
@ -155,13 +158,20 @@ if __name__ == "__main__":
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
opts = moptions.Options()
|
opts = moptions.Options()
|
||||||
# opts.mode = "reverse:example.com"
|
opts.mode = "reverse:example.com"
|
||||||
|
# test client-tls-first scenario
|
||||||
|
# opts.upstream_cert = False
|
||||||
|
|
||||||
|
layers.ClientTLSLayer.debug = ""
|
||||||
|
layers.ServerTLSLayer.debug = " "
|
||||||
|
layers.TCPLayer.debug = " "
|
||||||
|
|
||||||
async def handle(reader, writer):
|
async def handle(reader, writer):
|
||||||
layer_stack = [
|
layer_stack = [
|
||||||
# layers.TLSLayer,
|
layers.ClientTLSLayer,
|
||||||
lambda c: layers.HTTPLayer(c, HTTPMode.regular),
|
#layers.ServerTLSLayer,
|
||||||
layers.TCPLayer,
|
layers.TCPLayer,
|
||||||
|
# lambda c: layers.HTTPLayer(c, HTTPMode.transparent),
|
||||||
]
|
]
|
||||||
|
|
||||||
def next_layer(nl: layer.NextLayer):
|
def next_layer(nl: layer.NextLayer):
|
||||||
|
@ -57,7 +57,8 @@ def expect(*event_types):
|
|||||||
yield from f(self, event)
|
yield from f(self, event)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Invalid event type: Expected {}, got {}".format(event_types, event))
|
"Invalid event type at {}: Expected {}, got {}.".format(f, event_types, event)
|
||||||
|
)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user