mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] tls: various improvements
This commit is contained in:
parent
09b6257de0
commit
6cf0bec912
@ -78,6 +78,8 @@ class TlsConfig:
|
||||
tls_start.ssl_conn.set_app_data({
|
||||
"server_alpn": tls_start.context.server.alpn
|
||||
})
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
|
||||
|
||||
def create_proxy_server_ssl_conn(self, tls_start: tls.StartHookData) -> None:
|
||||
client = tls_start.context.client
|
||||
@ -126,6 +128,7 @@ class TlsConfig:
|
||||
**args
|
||||
)
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
|
||||
def configure(self, updated):
|
||||
if not any(x in updated for x in ["confdir", "certs"]):
|
||||
|
@ -1,6 +1,7 @@
|
||||
from enum import Flag, auto
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.options import Options
|
||||
|
||||
|
||||
@ -19,6 +20,7 @@ class Connection:
|
||||
state: ConnectionState
|
||||
tls: bool = False
|
||||
tls_established: bool = False
|
||||
certificate_chain: Optional[Sequence[certs.Cert]] = None
|
||||
alpn: Optional[bytes] = None
|
||||
alpn_offers: Sequence[bytes] = ()
|
||||
cipher_list: Sequence[bytes] = ()
|
||||
|
@ -4,10 +4,12 @@ from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.net import tls as net_tls
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2 import context
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
|
||||
|
||||
def is_tls_handshake_record(d: bytes) -> bool:
|
||||
@ -142,6 +144,17 @@ class _TLSLayer(layer.Layer):
|
||||
state = ", ".join(conn_states)
|
||||
return f"{type(self).__name__}({state})"
|
||||
|
||||
def start_tls(self, conn: context.Connection, initial_data: bytes = b""):
|
||||
assert conn not in self.tls
|
||||
assert conn.connected
|
||||
conn.tls = True
|
||||
|
||||
tls_start = StartHookData(conn, self.context)
|
||||
yield commands.Hook("tls_start", tls_start)
|
||||
self.tls[conn] = tls_start.ssl_conn
|
||||
|
||||
yield from self.negotiate(conn, initial_data)
|
||||
|
||||
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
|
||||
while True:
|
||||
try:
|
||||
@ -163,17 +176,37 @@ class _TLSLayer(layer.Layer):
|
||||
yield from self.tls_interact(conn)
|
||||
return False, None
|
||||
except SSL.Error as e:
|
||||
return False, repr(e)
|
||||
# provide more detailed information for some errors.
|
||||
last_err = e.args[0][-1]
|
||||
if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'):
|
||||
verify_result = SSL._lib.SSL_get_verify_result(self.tls[conn]._ssl)
|
||||
error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode()
|
||||
return False, f"Certificate verify failed: {error}"
|
||||
elif last_err == ('SSL routines', 'ssl3_read_bytes', 'tlsv1 alert unknown ca'):
|
||||
return False, "TLS Alert: Unknown CA"
|
||||
else:
|
||||
return False, repr(e)
|
||||
else:
|
||||
# Get all peer certificates.
|
||||
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
||||
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
||||
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
|
||||
all_certs = self.tls[conn].get_peer_cert_chain() or []
|
||||
if conn == self.context.client:
|
||||
cert = self.tls[conn].get_peer_certificate()
|
||||
if cert:
|
||||
all_certs.insert(0, cert)
|
||||
|
||||
|
||||
conn.tls_established = True
|
||||
conn.sni = self.tls[conn].get_servername()
|
||||
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
||||
conn.certificate_chain = [certs.Cert(x) for x in all_certs]
|
||||
conn.cipher_list = self.tls[conn].get_cipher_list()
|
||||
conn.tls_version = self.tls[conn].get_protocol_version_name()
|
||||
conn.timestamp_tls_setup = time.time()
|
||||
yield commands.Log(f"TLS established: {conn}")
|
||||
yield from self.receive(conn, b"")
|
||||
# TODO: Set all other connection attributes here
|
||||
return True, None
|
||||
|
||||
def receive(self, conn: context.Connection, data: bytes):
|
||||
@ -250,22 +283,10 @@ class ServerTLSLayer(_TLSLayer):
|
||||
for command in super().event_to_child(event):
|
||||
if isinstance(command, EstablishServerTLS):
|
||||
self.command_to_reply_to[command.connection] = command
|
||||
yield from self.start_server_tls(command.connection)
|
||||
yield from self.start_tls(command.connection)
|
||||
else:
|
||||
yield command
|
||||
|
||||
def start_server_tls(self, conn: context.Server):
|
||||
assert conn not in self.tls
|
||||
assert conn.connected
|
||||
conn.tls = True
|
||||
|
||||
tls_start = StartHookData(conn, self.context)
|
||||
yield commands.Hook("tls_start", tls_start)
|
||||
self.tls[conn] = tls_start.ssl_conn
|
||||
self.tls[conn].set_connect_state()
|
||||
|
||||
yield from self.negotiate(conn, b"")
|
||||
|
||||
|
||||
class ClientTLSLayer(_TLSLayer):
|
||||
"""
|
||||
@ -320,13 +341,12 @@ class ClientTLSLayer(_TLSLayer):
|
||||
yield commands.Log("Unable to establish TLS connection with server. "
|
||||
"Trying to establish TLS with client anyway.")
|
||||
|
||||
yield from self.start_client_tls()
|
||||
yield from self.start_tls(client, bytes(self.recv_buffer))
|
||||
self.recv_buffer.clear()
|
||||
self._handle_event = super()._handle_event
|
||||
|
||||
# In any case, we now have enough information to start server TLS if needed.
|
||||
yield from self.event_to_child(events.Start())
|
||||
elif isinstance(event, events.ConnectionClosed) and event.connection == client:
|
||||
self.recv_buffer.clear()
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
@ -351,29 +371,17 @@ class ClientTLSLayer(_TLSLayer):
|
||||
)
|
||||
return err
|
||||
|
||||
def start_client_tls(self) -> commands.TCommandGenerator:
|
||||
client = self.context.client
|
||||
tls_start = StartHookData(client, self.context)
|
||||
yield commands.Hook("tls_start", tls_start)
|
||||
self.tls[client] = tls_start.ssl_conn
|
||||
self.tls[client].set_accept_state()
|
||||
|
||||
yield from self.negotiate(client, bytes(self.recv_buffer))
|
||||
self.recv_buffer.clear()
|
||||
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||
done, err = yield from super().negotiate(conn, data)
|
||||
if err:
|
||||
if self.context.client.sni:
|
||||
# TODO: Also use other sources than SNI
|
||||
dest = " for " + self.context.client.sni.decode("idna")
|
||||
dest = self.context.client.sni.decode("idna")
|
||||
else:
|
||||
dest = ""
|
||||
dest = human.format_address(self.context.server.address)
|
||||
yield commands.Log(
|
||||
f"Client TLS Handshake failed. "
|
||||
f"The client may not trust the proxy's certificate{dest} ({err}).",
|
||||
f"The client may not trust the proxy's certificate for {dest} ({err}).",
|
||||
level="warn"
|
||||
|
||||
)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
return done
|
||||
|
@ -16,11 +16,16 @@ import traceback
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import http, options as moptions
|
||||
from mitmproxy.addons import tlsconfig
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events, layer, layers
|
||||
from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.utils import human
|
||||
from test.mitmproxy.proxy2.layers.test_tls import tlsdata
|
||||
|
||||
|
||||
class StreamIO(typing.NamedTuple):
|
||||
@ -270,9 +275,29 @@ if __name__ == "__main__":
|
||||
if "redirect" in flow.request.path:
|
||||
flow.request.host = "httpbin.org"
|
||||
|
||||
def tls_start(tls_start: tls.StartHookData):
|
||||
# INSECURE
|
||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if tls_start.conn == tls_start.context.client:
|
||||
ssl_context.use_privatekey_file(
|
||||
tlsdata.path("../../data/verificationcerts/trusted-leaf.key")
|
||||
)
|
||||
ssl_context.use_certificate_chain_file(
|
||||
tlsdata.path("../../data/verificationcerts/trusted-leaf.crt")
|
||||
)
|
||||
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||
|
||||
if tls_start.conn == tls_start.context.client:
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
else:
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni)
|
||||
|
||||
await SimpleConnectionHandler(reader, writer, opts, {
|
||||
"next_layer": next_layer,
|
||||
"request": request
|
||||
"request": request,
|
||||
"tls_start": tls_start,
|
||||
}).handle_client()
|
||||
|
||||
|
||||
|
@ -39,6 +39,7 @@ def sign(cert: str, subject: str):
|
||||
f"-extfile \"openssl-{cert}.conf\" "
|
||||
f"-out {cert}.crt"
|
||||
)
|
||||
os.remove(f"openssl-{cert}.conf")
|
||||
|
||||
|
||||
def mkcert(cert, subject):
|
||||
@ -63,7 +64,7 @@ h = do("openssl x509 -hash -noout -in trusted-root.crt").decode("ascii").strip()
|
||||
shutil.copyfile("trusted-root.crt", "{}.0".format(h))
|
||||
|
||||
# create trusted leaf cert.
|
||||
mkcert("trusted-leaf", f'DNS:{SUBJECT}' )
|
||||
mkcert("trusted-leaf", f'DNS:{SUBJECT}')
|
||||
|
||||
# create self-signed cert
|
||||
genrsa("self-signed")
|
||||
|
@ -155,7 +155,10 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
||||
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||
|
||||
if tls_start.conn != tls_start.context.client:
|
||||
if tls_start.conn == tls_start.context.client:
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
else:
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
# Set SNI
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni)
|
||||
|
||||
@ -233,6 +236,16 @@ class TestServerTLS:
|
||||
# Echo
|
||||
_test_echo(playbook, tssl, tctx.server)
|
||||
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
tssl.obj.unwrap()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||
<< commands.CloseConnection(tctx.server)
|
||||
>> events.ConnectionClosed(tctx.server)
|
||||
<< None
|
||||
)
|
||||
|
||||
def test_untrusted_cert(self, tctx):
|
||||
"""If the certificate is not trusted, we should fail."""
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
@ -263,8 +276,7 @@ class TestServerTLS:
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||
<< commands.SendData(tctx.client,
|
||||
b"server-tls-failed: Error([('SSL routines', 'tls_process_server_certificate', 'certificate verify failed')])")
|
||||
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
|
||||
)
|
||||
assert not tctx.server.tls_established
|
||||
|
||||
|
@ -189,7 +189,14 @@ class Playbook:
|
||||
if not self.logs:
|
||||
for offset, cmd in enumerate(cmds):
|
||||
pos = i + 1 + offset
|
||||
if isinstance(cmd, commands.Log) and not isinstance(self.expected[pos], commands.Log):
|
||||
need_to_emulate_log = (
|
||||
isinstance(cmd, commands.Log) and
|
||||
(
|
||||
pos >= len(self.expected)
|
||||
or not isinstance(self.expected[pos], commands.Log)
|
||||
)
|
||||
)
|
||||
if need_to_emulate_log:
|
||||
self.expected.insert(pos, cmd)
|
||||
if not self.hooks:
|
||||
last_cmd = self.actual[-1]
|
||||
@ -319,6 +326,8 @@ class EchoLayer(Layer):
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, events.DataReceived):
|
||||
yield commands.SendData(event.connection, event.data.lower())
|
||||
if isinstance(event, events.ConnectionClosed):
|
||||
yield commands.CloseConnection(event.connection)
|
||||
|
||||
|
||||
def reply_next_layer(
|
||||
|
Loading…
Reference in New Issue
Block a user