alpn: str -> bytes

This commit is contained in:
Maximilian Hils 2020-12-30 22:52:07 +01:00
parent 8ac5af62f5
commit dfba6e81a6
15 changed files with 47 additions and 54 deletions

View File

@ -8,7 +8,6 @@ from mitmproxy.net import tls as net_tls
from mitmproxy.options import CONF_BASENAME
from mitmproxy.proxy import context
from mitmproxy.proxy.layers import tls
from mitmproxy.utils.strutils import always_bytes
# We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default.
# https://ssl-config.mozilla.org/#config=old
@ -36,7 +35,7 @@ def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any:
return server_alpn
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
for alpn in options: # client sends in order of preference, so we are nice and respect that.
if alpn.decode(errors="replace") in http_alpns:
if alpn in http_alpns:
return alpn
else:
return SSL.NO_OVERLAPPING_PROTOCOLS
@ -138,7 +137,7 @@ class TlsConfig:
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_app_data(AppData(
server_alpn=always_bytes(server.alpn, "utf8", "replace"),
server_alpn=server.alpn,
http2=ctx.options.http2,
))
tls_start.ssl_conn.set_accept_state()
@ -155,14 +154,13 @@ class TlsConfig:
if server.sni is True:
server.sni = client.sni or server.address[0]
sni: Optional[bytes] = server.sni.encode("ascii") if server.sni else None
if not server.alpn_offers:
if client.alpn_offers:
if ctx.options.http2:
server.alpn_offers = tuple(client.alpn_offers)
else:
server.alpn_offers = tuple(x for x in client.alpn_offers if x != "h2")
server.alpn_offers = tuple(x for x in client.alpn_offers if x != b"h2")
elif client.tls_established:
# We would perfectly support HTTP/1 -> HTTP/2, but we want to keep things on the same protocol version.
# There are some edge cases where we want to mirror the regular server's behavior accurately,
@ -172,7 +170,6 @@ class TlsConfig:
server.alpn_offers = tls.HTTP_ALPNS
else:
server.alpn_offers = tls.HTTP1_ALPNS
alpn_offers: List[bytes] = [alpn.encode() for alpn in server.alpn_offers]
if not server.cipher_list and ctx.options.ciphers_server:
server.cipher_list = ctx.options.ciphers_server.split(":")
@ -195,15 +192,16 @@ class TlsConfig:
max_version=net_tls.Version[ctx.options.tls_version_client_max],
cipher_list=cipher_list,
verify=verify,
sni=sni,
sni=server.sni,
ca_path=ctx.options.ssl_verify_upstream_trusted_confdir,
ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca,
client_cert=client_cert,
alpn_protos=alpn_offers,
alpn_protos=server.alpn_offers,
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_tlsext_host_name(sni)
if server.sni:
tls_start.ssl_conn.set_tlsext_host_name(server.sni.encode())
tls_start.ssl_conn.set_connect_state()
def running(self):

View File

@ -235,11 +235,9 @@ def convert_10_11(data):
def conv_conn(conn):
conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace")
conn["alpn"] = strutils.always_str(conn.pop("alpn_proto_negotiated"), "utf8", "backslashreplace")
conn["alpn_offers"] = [
strutils.always_str(alpn, "utf8", "backslashreplace")
for alpn in (conn["alpn_offers"] or [])
]
conn["alpn"] = conn.pop("alpn_proto_negotiated")
conn["alpn_offers"] = conn["alpn_offers"] or []
conn["cipher_list"] = conn["cipher_list"] or []
conv_conn(data["client_conn"])
conv_conn(data["server_conn"])

View File

@ -130,7 +130,7 @@ def create_proxy_server_context(
max_version: Version,
cipher_list: Optional[Iterable[str]],
verify: Verify,
sni: Optional[bytes],
sni: Optional[str],
ca_path: Optional[str],
ca_pemfile: Optional[str],
client_cert: Optional[str],
@ -148,6 +148,7 @@ def create_proxy_server_context(
context.set_verify(verify.value, None)
if sni is not None:
assert isinstance(sni, str)
# Manually enable hostname verification on the context object.
# https://wiki.openssl.org/index.php/Hostname_validation
param = SSL._lib.SSL_CTX_get0_param(context._context)
@ -158,7 +159,7 @@ def create_proxy_server_context(
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
)
SSL._openssl_assert(
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni, 0) == 1
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni.encode(), 0) == 1
)
if ca_path is None and ca_pemfile is None:
@ -293,14 +294,11 @@ class ClientHello:
return None
@property
def alpn_protocols(self) -> List[str]:
def alpn_protocols(self) -> List[bytes]:
if self._client_hello.extensions:
for extension in self._client_hello.extensions.extensions:
if extension.type == 0x10:
try:
return [x.name.decode() for x in extension.body.alpn_protocols]
except UnicodeDecodeError:
return []
return list(x.name for x in extension.body.alpn_protocols)
return []
@property

View File

@ -56,8 +56,8 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
TLS version, with the exception of the end-entity certificate which
MUST be first.
"""
alpn: Optional[str] = None
alpn_offers: Sequence[str] = ()
alpn: Optional[bytes] = None
alpn_offers: Sequence[bytes] = ()
# we may want to add SSL_CIPHER_description here, but that's currently not exposed by cryptography
cipher: Optional[str] = None
@ -98,9 +98,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
@property
def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", DeprecationWarning)
if self.alpn is not None:
return self.alpn.encode()
return None
return self.alpn
class Client(Connection):

View File

@ -351,7 +351,7 @@ class HttpStream(layer.Layer):
yield HttpErrorHook(self.flow)
# For HTTP/2 we only want to kill the specific stream, for HTTP/1 we want to kill the connection
# *without* sending an HTTP response (that could be achieved by the user by setting flow.response).
if self.context.client.alpn == "h2":
if self.context.client.alpn == b"h2":
yield SendHttp(ResponseProtocolError(self.stream_id, "killed"), self.context.client)
else:
if self.context.client.state & ConnectionState.CAN_WRITE:
@ -532,7 +532,7 @@ class HttpLayer(layer.Layer):
self.command_sources = {}
http_conn: HttpConnection
if self.context.client.alpn == "h2":
if self.context.client.alpn == b"h2":
http_conn = Http2Server(context.fork())
else:
http_conn = Http1Server(context.fork())
@ -606,10 +606,10 @@ class HttpLayer(layer.Layer):
for connection in self.connections:
# see "tricky multiplexing edge case" in make_http_connection for an explanation
conn_is_pending_or_h2 = (
connection.alpn == "h2"
connection.alpn == b"h2"
or connection in self.waiting_for_establishment
)
h2_to_h1 = self.context.client.alpn == "h2" and not conn_is_pending_or_h2
h2_to_h1 = self.context.client.alpn == b"h2" and not conn_is_pending_or_h2
connection_suitable = (
event.connection_spec_matches(connection)
and not h2_to_h1
@ -679,7 +679,7 @@ class HttpLayer(layer.Layer):
# that neither have a content-length specified nor a chunked transfer encoding.
# We can't process these two flows to the same h1 connection as they would both have
# "read until eof" semantics. The only workaround left is to open a separate connection for each flow.
if not command.err and self.context.client.alpn == "h2" and command.connection.alpn != "h2":
if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
for cmd in waiting[1:]:
yield from self.get_connection(cmd, reuse=False)
break
@ -695,7 +695,7 @@ class HttpClient(layer.Layer):
err = yield commands.OpenConnection(self.context.server)
if not err:
child_layer: layer.Layer
if self.context.server.alpn == "h2":
if self.context.server.alpn == b"h2":
child_layer = Http2Client(self.context)
else:
child_layer = Http1Client(self.context)

View File

@ -91,8 +91,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
return None
HTTP1_ALPNS = ("http/1.1", "http/1.0", "http/0.9")
HTTP_ALPNS = ("h2",) + HTTP1_ALPNS
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
# We need these classes as hooks can only have one argument at the moment.
@ -196,7 +196,7 @@ class _TLSLayer(tunnel.TunnelLayer):
all_certs.insert(0, cert)
self.conn.timestamp_tls_setup = time.time()
self.conn.alpn = self.tls.get_alpn_proto_negotiated().decode()
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs]
self.conn.cipher = self.tls.get_cipher_name()
self.conn.tls_version = self.tls.get_protocol_version_name()

View File

@ -158,7 +158,7 @@ def tclient_conn() -> context.Client:
timestamp_end=946681206,
sni="address",
cipher_name="cipher",
alpn="http/1.1",
alpn=b"http/1.1",
tls_version="TLSv1.2",
tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))],
state=0,

View File

@ -4,7 +4,7 @@ import urwid
import mitmproxy.flow
from mitmproxy import http
from mitmproxy.tools.console import common, searchable
from mitmproxy.utils import human
from mitmproxy.utils import human, strutils
def maybe_timestamp(base, attr):
@ -49,7 +49,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
if resp:
parts.append(("HTTP Version", resp.http_version))
if sc.alpn:
parts.append(("ALPN", sc.alpn))
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn)))
text.extend(
common.format_keyvals(parts, indent=4)
@ -69,7 +69,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
]
if c.altnames:
parts.append(("Alt names", ", ".join(c.altnames)))
parts.append(("Alt names", ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)))
text.extend(
common.format_keyvals(parts, indent=4)
)
@ -89,7 +89,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
if cc.cipher:
parts.append(("Cipher Name", cc.cipher))
if cc.alpn:
parts.append(("ALPN", cc.alpn))
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn)))
text.extend(
common.format_keyvals(parts, indent=4)

View File

@ -20,6 +20,7 @@ from mitmproxy import io
from mitmproxy import log
from mitmproxy import optmanager
from mitmproxy import version
from mitmproxy.utils.strutils import always_str
def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
@ -48,7 +49,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"timestamp_end": flow.client_conn.timestamp_end,
"sni": flow.client_conn.sni,
"cipher_name": flow.client_conn.cipher,
"alpn_proto_negotiated": flow.client_conn.alpn,
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
"tls_version": flow.client_conn.tls_version,
}
@ -60,7 +61,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"source_address": flow.server_conn.sockname,
"tls_established": flow.server_conn.tls_established,
"sni": flow.server_conn.sni,
"alpn_proto_negotiated": flow.server_conn.alpn,
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
"tls_version": flow.server_conn.tls_version,
"timestamp_start": flow.server_conn.timestamp_start,
"timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup,

View File

@ -127,7 +127,7 @@ class TestTlsConfig:
ta = tlsconfig.TlsConfig()
with taddons.context(ta) as tctx:
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
ctx.client.alpn_offers = ["h2"]
ctx.client.alpn_offers = [b"h2"]
ctx.client.cipher_list = ["TLS_AES_256_GCM_SHA384", "ECDHE-RSA-AES128-SHA"]
ctx.server.address = ("example.mitmproxy.org", 443)
@ -185,8 +185,8 @@ class TestTlsConfig:
ta.tls_start(tls_start)
assert ctx.server.alpn_offers == expected
assert_alpn(True, tls.HTTP_ALPNS + ("foo",), tls.HTTP_ALPNS + ("foo",))
assert_alpn(False, tls.HTTP_ALPNS + ("foo",), tls.HTTP1_ALPNS + ("foo",))
assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",))
assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",))
assert_alpn(True, [], tls.HTTP_ALPNS)
assert_alpn(False, [], tls.HTTP1_ALPNS)
ctx.client.timestamp_tls_setup = time.time()

View File

@ -110,7 +110,7 @@ class TestClientHello:
49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161,
49171, 49162, 49172, 156, 157, 47, 53, 10
]
assert c.alpn_protocols == ['h2', 'http/1.1']
assert c.alpn_protocols == [b'h2', b'http/1.1']
assert c.extensions == [
(65281, b'\x00'),
(0, b'\x00\x0e\x00\x00\x0bexample.com'),

View File

@ -44,7 +44,7 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
tctx.client.alpn = "h2"
tctx.client.alpn = b"h2"
frame_factory = FrameFactory()
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
@ -58,7 +58,7 @@ def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
def make_h2(open_connection: OpenConnection) -> None:
open_connection.connection.alpn = "h2"
open_connection.connection.alpn = b"h2"
def test_simple(tctx):

View File

@ -208,7 +208,7 @@ def h2_frames(draw):
def h2_layer(opts):
tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts)
tctx.client.alpn = "h2"
tctx.client.alpn = b"h2"
layer = http.HttpLayer(tctx, HTTPMode.regular)
for _ in layer.handle_event(Start()):

View File

@ -22,7 +22,7 @@ def event_types(events):
def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]:
tctx.client.alpn = "h2"
tctx.client.alpn = b"h2"
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
conn = h2.connection.H2Connection()

View File

@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
tls_start.ssl_conn = SSL.Connection(ssl_context)
tls_start.ssl_conn.set_connect_state()
# Set SNI
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode("ascii"))
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode())
# Manually enable hostname verification.
# Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in
@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
)
SSL._openssl_assert(
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode("ascii"), 0) == 1
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode(), 0) == 1
)
return tutils.reply(*args, side_effect=make_conn, **kwargs)
@ -446,8 +446,8 @@ class TestClientTLS:
assert tctx.client.tls_established
assert tctx.server.tls_established
assert tctx.server.sni == tctx.client.sni
assert tctx.client.alpn == "quux"
assert tctx.server.alpn == "quux"
assert tctx.client.alpn == b"quux"
assert tctx.server.alpn == b"quux"
_test_echo(playbook, tssl_server, tctx.server)
_test_echo(playbook, tssl_client, tctx.client)