mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 10:16:27 +00:00
[sans-io] tls: various improvements
This commit is contained in:
parent
09b6257de0
commit
6cf0bec912
@ -78,6 +78,8 @@ class TlsConfig:
|
|||||||
tls_start.ssl_conn.set_app_data({
|
tls_start.ssl_conn.set_app_data({
|
||||||
"server_alpn": tls_start.context.server.alpn
|
"server_alpn": tls_start.context.server.alpn
|
||||||
})
|
})
|
||||||
|
tls_start.ssl_conn.set_accept_state()
|
||||||
|
|
||||||
|
|
||||||
def create_proxy_server_ssl_conn(self, tls_start: tls.StartHookData) -> None:
|
def create_proxy_server_ssl_conn(self, tls_start: tls.StartHookData) -> None:
|
||||||
client = tls_start.context.client
|
client = tls_start.context.client
|
||||||
@ -126,6 +128,7 @@ class TlsConfig:
|
|||||||
**args
|
**args
|
||||||
)
|
)
|
||||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||||
|
tls_start.ssl_conn.set_connect_state()
|
||||||
|
|
||||||
def configure(self, updated):
|
def configure(self, updated):
|
||||||
if not any(x in updated for x in ["confdir", "certs"]):
|
if not any(x in updated for x in ["confdir", "certs"]):
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from enum import Flag, auto
|
from enum import Flag, auto
|
||||||
from typing import List, Optional, Sequence, Union
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from mitmproxy import certs
|
||||||
from mitmproxy.options import Options
|
from mitmproxy.options import Options
|
||||||
|
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ class Connection:
|
|||||||
state: ConnectionState
|
state: ConnectionState
|
||||||
tls: bool = False
|
tls: bool = False
|
||||||
tls_established: bool = False
|
tls_established: bool = False
|
||||||
|
certificate_chain: Optional[Sequence[certs.Cert]] = None
|
||||||
alpn: Optional[bytes] = None
|
alpn: Optional[bytes] = None
|
||||||
alpn_offers: Sequence[bytes] = ()
|
alpn_offers: Sequence[bytes] = ()
|
||||||
cipher_list: Sequence[bytes] = ()
|
cipher_list: Sequence[bytes] = ()
|
||||||
|
@ -4,10 +4,12 @@ from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
|||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
|
from mitmproxy import certs
|
||||||
from mitmproxy.net import tls as net_tls
|
from mitmproxy.net import tls as net_tls
|
||||||
from mitmproxy.proxy2 import commands, events, layer
|
from mitmproxy.proxy2 import commands, events, layer
|
||||||
from mitmproxy.proxy2 import context
|
from mitmproxy.proxy2 import context
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
|
from mitmproxy.utils import human
|
||||||
|
|
||||||
|
|
||||||
def is_tls_handshake_record(d: bytes) -> bool:
|
def is_tls_handshake_record(d: bytes) -> bool:
|
||||||
@ -142,6 +144,17 @@ class _TLSLayer(layer.Layer):
|
|||||||
state = ", ".join(conn_states)
|
state = ", ".join(conn_states)
|
||||||
return f"{type(self).__name__}({state})"
|
return f"{type(self).__name__}({state})"
|
||||||
|
|
||||||
|
def start_tls(self, conn: context.Connection, initial_data: bytes = b""):
|
||||||
|
assert conn not in self.tls
|
||||||
|
assert conn.connected
|
||||||
|
conn.tls = True
|
||||||
|
|
||||||
|
tls_start = StartHookData(conn, self.context)
|
||||||
|
yield commands.Hook("tls_start", tls_start)
|
||||||
|
self.tls[conn] = tls_start.ssl_conn
|
||||||
|
|
||||||
|
yield from self.negotiate(conn, initial_data)
|
||||||
|
|
||||||
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
|
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -163,17 +176,37 @@ class _TLSLayer(layer.Layer):
|
|||||||
yield from self.tls_interact(conn)
|
yield from self.tls_interact(conn)
|
||||||
return False, None
|
return False, None
|
||||||
except SSL.Error as e:
|
except SSL.Error as e:
|
||||||
return False, repr(e)
|
# provide more detailed information for some errors.
|
||||||
|
last_err = e.args[0][-1]
|
||||||
|
if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'):
|
||||||
|
verify_result = SSL._lib.SSL_get_verify_result(self.tls[conn]._ssl)
|
||||||
|
error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode()
|
||||||
|
return False, f"Certificate verify failed: {error}"
|
||||||
|
elif last_err == ('SSL routines', 'ssl3_read_bytes', 'tlsv1 alert unknown ca'):
|
||||||
|
return False, "TLS Alert: Unknown CA"
|
||||||
|
else:
|
||||||
|
return False, repr(e)
|
||||||
else:
|
else:
|
||||||
|
# Get all peer certificates.
|
||||||
|
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
||||||
|
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
||||||
|
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
|
||||||
|
all_certs = self.tls[conn].get_peer_cert_chain() or []
|
||||||
|
if conn == self.context.client:
|
||||||
|
cert = self.tls[conn].get_peer_certificate()
|
||||||
|
if cert:
|
||||||
|
all_certs.insert(0, cert)
|
||||||
|
|
||||||
|
|
||||||
conn.tls_established = True
|
conn.tls_established = True
|
||||||
conn.sni = self.tls[conn].get_servername()
|
conn.sni = self.tls[conn].get_servername()
|
||||||
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
||||||
|
conn.certificate_chain = [certs.Cert(x) for x in all_certs]
|
||||||
conn.cipher_list = self.tls[conn].get_cipher_list()
|
conn.cipher_list = self.tls[conn].get_cipher_list()
|
||||||
conn.tls_version = self.tls[conn].get_protocol_version_name()
|
conn.tls_version = self.tls[conn].get_protocol_version_name()
|
||||||
conn.timestamp_tls_setup = time.time()
|
conn.timestamp_tls_setup = time.time()
|
||||||
yield commands.Log(f"TLS established: {conn}")
|
yield commands.Log(f"TLS established: {conn}")
|
||||||
yield from self.receive(conn, b"")
|
yield from self.receive(conn, b"")
|
||||||
# TODO: Set all other connection attributes here
|
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def receive(self, conn: context.Connection, data: bytes):
|
def receive(self, conn: context.Connection, data: bytes):
|
||||||
@ -250,22 +283,10 @@ class ServerTLSLayer(_TLSLayer):
|
|||||||
for command in super().event_to_child(event):
|
for command in super().event_to_child(event):
|
||||||
if isinstance(command, EstablishServerTLS):
|
if isinstance(command, EstablishServerTLS):
|
||||||
self.command_to_reply_to[command.connection] = command
|
self.command_to_reply_to[command.connection] = command
|
||||||
yield from self.start_server_tls(command.connection)
|
yield from self.start_tls(command.connection)
|
||||||
else:
|
else:
|
||||||
yield command
|
yield command
|
||||||
|
|
||||||
def start_server_tls(self, conn: context.Server):
|
|
||||||
assert conn not in self.tls
|
|
||||||
assert conn.connected
|
|
||||||
conn.tls = True
|
|
||||||
|
|
||||||
tls_start = StartHookData(conn, self.context)
|
|
||||||
yield commands.Hook("tls_start", tls_start)
|
|
||||||
self.tls[conn] = tls_start.ssl_conn
|
|
||||||
self.tls[conn].set_connect_state()
|
|
||||||
|
|
||||||
yield from self.negotiate(conn, b"")
|
|
||||||
|
|
||||||
|
|
||||||
class ClientTLSLayer(_TLSLayer):
|
class ClientTLSLayer(_TLSLayer):
|
||||||
"""
|
"""
|
||||||
@ -320,13 +341,12 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
yield commands.Log("Unable to establish TLS connection with server. "
|
yield commands.Log("Unable to establish TLS connection with server. "
|
||||||
"Trying to establish TLS with client anyway.")
|
"Trying to establish TLS with client anyway.")
|
||||||
|
|
||||||
yield from self.start_client_tls()
|
yield from self.start_tls(client, bytes(self.recv_buffer))
|
||||||
|
self.recv_buffer.clear()
|
||||||
self._handle_event = super()._handle_event
|
self._handle_event = super()._handle_event
|
||||||
|
|
||||||
# In any case, we now have enough information to start server TLS if needed.
|
# In any case, we now have enough information to start server TLS if needed.
|
||||||
yield from self.event_to_child(events.Start())
|
yield from self.event_to_child(events.Start())
|
||||||
elif isinstance(event, events.ConnectionClosed) and event.connection == client:
|
|
||||||
self.recv_buffer.clear()
|
|
||||||
else:
|
else:
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
@ -351,29 +371,17 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
|
|
||||||
def start_client_tls(self) -> commands.TCommandGenerator:
|
|
||||||
client = self.context.client
|
|
||||||
tls_start = StartHookData(client, self.context)
|
|
||||||
yield commands.Hook("tls_start", tls_start)
|
|
||||||
self.tls[client] = tls_start.ssl_conn
|
|
||||||
self.tls[client].set_accept_state()
|
|
||||||
|
|
||||||
yield from self.negotiate(client, bytes(self.recv_buffer))
|
|
||||||
self.recv_buffer.clear()
|
|
||||||
|
|
||||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||||
done, err = yield from super().negotiate(conn, data)
|
done, err = yield from super().negotiate(conn, data)
|
||||||
if err:
|
if err:
|
||||||
if self.context.client.sni:
|
if self.context.client.sni:
|
||||||
# TODO: Also use other sources than SNI
|
dest = self.context.client.sni.decode("idna")
|
||||||
dest = " for " + self.context.client.sni.decode("idna")
|
|
||||||
else:
|
else:
|
||||||
dest = ""
|
dest = human.format_address(self.context.server.address)
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
f"Client TLS Handshake failed. "
|
f"Client TLS Handshake failed. "
|
||||||
f"The client may not trust the proxy's certificate{dest} ({err}).",
|
f"The client may not trust the proxy's certificate for {dest} ({err}).",
|
||||||
level="warn"
|
level="warn"
|
||||||
|
|
||||||
)
|
)
|
||||||
yield commands.CloseConnection(self.context.client)
|
yield commands.CloseConnection(self.context.client)
|
||||||
return done
|
return done
|
||||||
|
@ -16,11 +16,16 @@ import traceback
|
|||||||
import typing
|
import typing
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from mitmproxy import http, options as moptions
|
from mitmproxy import http, options as moptions
|
||||||
|
from mitmproxy.addons import tlsconfig
|
||||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
from mitmproxy.proxy2 import commands, events, layer, layers
|
from mitmproxy.proxy2 import commands, events, layer, layers
|
||||||
from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
|
from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
|
||||||
|
from mitmproxy.proxy2.layers import tls
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
|
from test.mitmproxy.proxy2.layers.test_tls import tlsdata
|
||||||
|
|
||||||
|
|
||||||
class StreamIO(typing.NamedTuple):
|
class StreamIO(typing.NamedTuple):
|
||||||
@ -270,9 +275,29 @@ if __name__ == "__main__":
|
|||||||
if "redirect" in flow.request.path:
|
if "redirect" in flow.request.path:
|
||||||
flow.request.host = "httpbin.org"
|
flow.request.host = "httpbin.org"
|
||||||
|
|
||||||
|
def tls_start(tls_start: tls.StartHookData):
|
||||||
|
# INSECURE
|
||||||
|
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
|
if tls_start.conn == tls_start.context.client:
|
||||||
|
ssl_context.use_privatekey_file(
|
||||||
|
tlsdata.path("../../data/verificationcerts/trusted-leaf.key")
|
||||||
|
)
|
||||||
|
ssl_context.use_certificate_chain_file(
|
||||||
|
tlsdata.path("../../data/verificationcerts/trusted-leaf.crt")
|
||||||
|
)
|
||||||
|
|
||||||
|
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||||
|
|
||||||
|
if tls_start.conn == tls_start.context.client:
|
||||||
|
tls_start.ssl_conn.set_accept_state()
|
||||||
|
else:
|
||||||
|
tls_start.ssl_conn.set_connect_state()
|
||||||
|
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni)
|
||||||
|
|
||||||
await SimpleConnectionHandler(reader, writer, opts, {
|
await SimpleConnectionHandler(reader, writer, opts, {
|
||||||
"next_layer": next_layer,
|
"next_layer": next_layer,
|
||||||
"request": request
|
"request": request,
|
||||||
|
"tls_start": tls_start,
|
||||||
}).handle_client()
|
}).handle_client()
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,6 +39,7 @@ def sign(cert: str, subject: str):
|
|||||||
f"-extfile \"openssl-{cert}.conf\" "
|
f"-extfile \"openssl-{cert}.conf\" "
|
||||||
f"-out {cert}.crt"
|
f"-out {cert}.crt"
|
||||||
)
|
)
|
||||||
|
os.remove(f"openssl-{cert}.conf")
|
||||||
|
|
||||||
|
|
||||||
def mkcert(cert, subject):
|
def mkcert(cert, subject):
|
||||||
@ -63,7 +64,7 @@ h = do("openssl x509 -hash -noout -in trusted-root.crt").decode("ascii").strip()
|
|||||||
shutil.copyfile("trusted-root.crt", "{}.0".format(h))
|
shutil.copyfile("trusted-root.crt", "{}.0".format(h))
|
||||||
|
|
||||||
# create trusted leaf cert.
|
# create trusted leaf cert.
|
||||||
mkcert("trusted-leaf", f'DNS:{SUBJECT}' )
|
mkcert("trusted-leaf", f'DNS:{SUBJECT}')
|
||||||
|
|
||||||
# create self-signed cert
|
# create self-signed cert
|
||||||
genrsa("self-signed")
|
genrsa("self-signed")
|
||||||
@ -72,4 +73,4 @@ do("openssl req -x509 -new -nodes -batch "
|
|||||||
f'-addext "subjectAltName = DNS:{SUBJECT}" '
|
f'-addext "subjectAltName = DNS:{SUBJECT}" '
|
||||||
"-days 7300 "
|
"-days 7300 "
|
||||||
"-out self-signed.crt"
|
"-out self-signed.crt"
|
||||||
)
|
)
|
@ -155,7 +155,10 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
|||||||
|
|
||||||
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||||
|
|
||||||
if tls_start.conn != tls_start.context.client:
|
if tls_start.conn == tls_start.context.client:
|
||||||
|
tls_start.ssl_conn.set_accept_state()
|
||||||
|
else:
|
||||||
|
tls_start.ssl_conn.set_connect_state()
|
||||||
# Set SNI
|
# Set SNI
|
||||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni)
|
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni)
|
||||||
|
|
||||||
@ -233,6 +236,16 @@ class TestServerTLS:
|
|||||||
# Echo
|
# Echo
|
||||||
_test_echo(playbook, tssl, tctx.server)
|
_test_echo(playbook, tssl, tctx.server)
|
||||||
|
|
||||||
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
|
tssl.obj.unwrap()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||||
|
<< commands.CloseConnection(tctx.server)
|
||||||
|
>> events.ConnectionClosed(tctx.server)
|
||||||
|
<< None
|
||||||
|
)
|
||||||
|
|
||||||
def test_untrusted_cert(self, tctx):
|
def test_untrusted_cert(self, tctx):
|
||||||
"""If the certificate is not trusted, we should fail."""
|
"""If the certificate is not trusted, we should fail."""
|
||||||
layer = tls.ServerTLSLayer(tctx)
|
layer = tls.ServerTLSLayer(tctx)
|
||||||
@ -263,8 +276,7 @@ class TestServerTLS:
|
|||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||||
<< commands.SendData(tctx.client,
|
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
|
||||||
b"server-tls-failed: Error([('SSL routines', 'tls_process_server_certificate', 'certificate verify failed')])")
|
|
||||||
)
|
)
|
||||||
assert not tctx.server.tls_established
|
assert not tctx.server.tls_established
|
||||||
|
|
||||||
|
@ -189,7 +189,14 @@ class Playbook:
|
|||||||
if not self.logs:
|
if not self.logs:
|
||||||
for offset, cmd in enumerate(cmds):
|
for offset, cmd in enumerate(cmds):
|
||||||
pos = i + 1 + offset
|
pos = i + 1 + offset
|
||||||
if isinstance(cmd, commands.Log) and not isinstance(self.expected[pos], commands.Log):
|
need_to_emulate_log = (
|
||||||
|
isinstance(cmd, commands.Log) and
|
||||||
|
(
|
||||||
|
pos >= len(self.expected)
|
||||||
|
or not isinstance(self.expected[pos], commands.Log)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if need_to_emulate_log:
|
||||||
self.expected.insert(pos, cmd)
|
self.expected.insert(pos, cmd)
|
||||||
if not self.hooks:
|
if not self.hooks:
|
||||||
last_cmd = self.actual[-1]
|
last_cmd = self.actual[-1]
|
||||||
@ -319,6 +326,8 @@ class EchoLayer(Layer):
|
|||||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
yield commands.SendData(event.connection, event.data.lower())
|
yield commands.SendData(event.connection, event.data.lower())
|
||||||
|
if isinstance(event, events.ConnectionClosed):
|
||||||
|
yield commands.CloseConnection(event.connection)
|
||||||
|
|
||||||
|
|
||||||
def reply_next_layer(
|
def reply_next_layer(
|
||||||
|
Loading…
Reference in New Issue
Block a user