diff --git a/mitmproxy/contrib/kaitaistruct/make.sh b/mitmproxy/contrib/kaitaistruct/make.sh index 218d51989..9ef688865 100755 --- a/mitmproxy/contrib/kaitaistruct/make.sh +++ b/mitmproxy/contrib/kaitaistruct/make.sh @@ -6,5 +6,6 @@ wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/gif.ksy wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/jpeg.ksy wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/png.ksy +wget -N https://raw.githubusercontent.com/mitmproxy/mitmproxy/master/mitmproxy/contrib/tls_client_hello.py kaitai-struct-compiler --target python --opaque-types=true *.ksy diff --git a/mitmproxy/contrib/kaitaistruct/tls_client_hello.py b/mitmproxy/contrib/kaitaistruct/tls_client_hello.py new file mode 100644 index 000000000..6aff9b142 --- /dev/null +++ b/mitmproxy/contrib/kaitaistruct/tls_client_hello.py @@ -0,0 +1,146 @@ +# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild + +import array +import struct +import zlib +from enum import Enum +from pkg_resources import parse_version + +from kaitaistruct import __version__ as ks_version, KaitaiStruct, KaitaiStream, BytesIO + +if parse_version(ks_version) < parse_version('0.7'): + raise Exception("Incompatible Kaitai Struct Python API: 0.7 or later is required, but you have %s" % (ks_version)) + + +class TlsClientHello(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.version = self._root.Version(self._io, self, self._root) + self.random = self._root.Random(self._io, self, self._root) + self.session_id = self._root.SessionId(self._io, self, self._root) + self.cipher_suites = self._root.CipherSuites(self._io, self, self._root) + self.compression_methods = self._root.CompressionMethods(self._io, self, self._root) + if self._io.is_eof() == True: + self.extensions = [None] * (0) + for i in range(0): + self.extensions[i] = self._io.read_bytes(0) + + if self._io.is_eof() == False: + self.extensions = self._root.Extensions(self._io, self, self._root) + + class ServerName(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.name_type = self._io.read_u1() + self.length = self._io.read_u2be() + self.host_name = self._io.read_bytes(self.length) + + class Random(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.gmt_unix_time = self._io.read_u4be() + self.random = self._io.read_bytes(28) + + class SessionId(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u1() + self.sid = self._io.read_bytes(self.len) + + class Sni(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.list_length = self._io.read_u2be() + self.server_names = [] + while not self._io.is_eof(): + self.server_names.append(self._root.ServerName(self._io, self, self._root)) + + class CipherSuites(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u2be() + self.cipher_suites = [None] * (self.len // 2) + for i in range(self.len // 2): + self.cipher_suites[i] = self._root.CipherSuite(self._io, self, self._root) + + class CompressionMethods(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u1() + self.compression_methods = self._io.read_bytes(self.len) + + class Alpn(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.ext_len = self._io.read_u2be() + self.alpn_protocols = [] + while not self._io.is_eof(): + self.alpn_protocols.append(self._root.Protocol(self._io, self, self._root)) + + class Extensions(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u2be() + self.extensions = [] + while not self._io.is_eof(): + self.extensions.append(self._root.Extension(self._io, self, self._root)) + + class Version(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.major = self._io.read_u1() + self.minor = self._io.read_u1() + + class CipherSuite(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.cipher_suite = self._io.read_u2be() + + class Protocol(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.strlen = self._io.read_u1() + self.name = self._io.read_bytes(self.strlen) + + class Extension(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.type = self._io.read_u2be() + self.len = self._io.read_u2be() + _on = self.type + if _on == 0: + self._raw_body = self._io.read_bytes(self.len) + io = KaitaiStream(BytesIO(self._raw_body)) + self.body = self._root.Sni(io, self, self._root) + elif _on == 16: + self._raw_body = self._io.read_bytes(self.len) + io = KaitaiStream(BytesIO(self._raw_body)) + self.body = self._root.Alpn(io, self, self._root) + else: + self.body = self._io.read_bytes(self.len) diff --git a/mitmproxy/contrib/tls_client_hello.ksy b/mitmproxy/contrib/tls_client_hello.ksy new file mode 100644 index 000000000..5b6eb0fb9 --- /dev/null +++ b/mitmproxy/contrib/tls_client_hello.ksy @@ -0,0 +1,139 @@ +meta: + id: tls_client_hello + endian: be + +seq: + - id: version + type: version + + - id: random + type: random + + - id: session_id + type: session_id + + - id: cipher_suites + type: cipher_suites + + - id: compression_methods + type: compression_methods + + - id: extensions + size: 0 + repeat: expr + repeat-expr: 0 + if: _io.eof == true + + - id: extensions + type: extensions + if: _io.eof == false + +types: + version: + seq: + - id: major + type: u1 + + - id: minor + type: u1 + + random: + seq: + - id: gmt_unix_time + type: u4 + + - id: random + size: 28 + + session_id: + seq: + - id: len + type: u1 + + - id: sid + size: len + + cipher_suites: + seq: + - id: len + type: u2 + + - id: cipher_suites + type: cipher_suite + repeat: expr + repeat-expr: len/2 + + cipher_suite: + seq: + - id: cipher_suite + type: u2 + + compression_methods: + seq: + - id: len + type: u1 + + - id: compression_methods + size: len + + extensions: + seq: + - id: len + type: u2 + + - id: extensions + type: extension + repeat: eos + + extension: + seq: + - id: type + type: u2 + + - id: len + type: u2 + + - id: body + size: len + type: + switch-on: type + cases: + 0: sni + 16: alpn + + sni: + seq: + - id: list_length + type: u2 + + - id: server_names + type: server_name + repeat: eos + + server_name: + seq: + - id: name_type + type: u1 + + - id: length + type: u2 + + - id: host_name + size: length + + alpn: + seq: + - id: ext_len + type: u2 + + - id: alpn_protocols + type: protocol + repeat: eos + + protocol: + seq: + - id: strlen + type: u1 + + - id: name + size: strlen diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py index f55855f0a..d42c7fdd2 100644 --- a/mitmproxy/proxy/protocol/tls.py +++ b/mitmproxy/proxy/protocol/tls.py @@ -1,10 +1,11 @@ import struct from typing import Optional # noqa from typing import Union +import io -import construct +from kaitaistruct import KaitaiStream from mitmproxy import exceptions -from mitmproxy.contrib import tls_parser +from mitmproxy.contrib.kaitaistruct import tls_client_hello from mitmproxy.proxy.protocol import base from mitmproxy.net import check @@ -263,7 +264,7 @@ def get_client_hello(client_conn): class TlsClientHello: def __init__(self, raw_client_hello): - self._client_hello = tls_parser.ClientHello.parse(raw_client_hello) + self._client_hello = tls_client_hello.TlsClientHello(KaitaiStream(io.BytesIO(raw_client_hello))) def raw(self): return self._client_hello @@ -278,12 +279,12 @@ class TlsClientHello: for extension in self._client_hello.extensions.extensions: is_valid_sni_extension = ( extension.type == 0x00 and - len(extension.server_names) == 1 and - extension.server_names[0].name_type == 0 and - check.is_valid_host(extension.server_names[0].host_name) + len(extension.body.server_names) == 1 and + extension.body.server_names[0].name_type == 0 and + check.is_valid_host(extension.body.server_names[0].host_name) ) if is_valid_sni_extension: - return extension.server_names[0].host_name.decode("idna") + return extension.body.server_names[0].host_name.decode("idna") return None @property @@ -291,7 +292,7 @@ class TlsClientHello: if self._client_hello.extensions: for extension in self._client_hello.extensions.extensions: if extension.type == 0x10: - return list(extension.alpn_protocols) + return list(extension.body.alpn_protocols) return [] @classmethod @@ -310,7 +311,7 @@ class TlsClientHello: try: return cls(raw_client_hello) - except construct.ConstructError as e: + except EOFError as e: raise exceptions.TlsProtocolException( 'Cannot parse Client Hello: %s, Raw Client Hello: %s' % (repr(e), raw_client_hello.encode("hex")) @@ -518,7 +519,8 @@ class TlsLayer(base.Layer): # We only support http/1.1 and h2. # If the server only supports spdy (next to http/1.1), it may select that # and mitmproxy would enter TCP passthrough mode, which we want to avoid. - alpn = [x for x in self._client_hello.alpn_protocols if not (x.startswith(b"h2-") or x.startswith(b"spdy"))] + alpn = [x.name for x in self._client_hello.alpn_protocols if + not (x.name.startswith(b"h2-") or x.name.startswith(b"spdy"))] if alpn and b"h2" in alpn and not self.config.options.http2: alpn.remove(b"h2") @@ -537,8 +539,8 @@ class TlsLayer(base.Layer): if not ciphers_server and self._client_tls: ciphers_server = [] for id in self._client_hello.cipher_suites: - if id in CIPHER_ID_NAME_MAP.keys(): - ciphers_server.append(CIPHER_ID_NAME_MAP[id]) + if id.cipher_suite in CIPHER_ID_NAME_MAP.keys(): + ciphers_server.append(CIPHER_ID_NAME_MAP[id.cipher_suite]) ciphers_server = ':'.join(ciphers_server) self.server_conn.establish_ssl( diff --git a/test/mitmproxy/contrib/test_tls_parser.py b/test/mitmproxy/contrib/test_tls_parser.py index 66972b623..e4d9177f4 100644 --- a/test/mitmproxy/contrib/test_tls_parser.py +++ b/test/mitmproxy/contrib/test_tls_parser.py @@ -1,4 +1,6 @@ -from mitmproxy.contrib import tls_parser +import io +from kaitaistruct import KaitaiStream +from mitmproxy.contrib.kaitaistruct import tls_client_hello def test_parse_chrome(): @@ -12,18 +14,20 @@ def test_parse_chrome(): "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" "170018" ) - c = tls_parser.ClientHello.parse(data) + + c = tls_client_hello.TlsClientHello(KaitaiStream(io.BytesIO(data))) assert c.version.major == 3 assert c.version.minor == 3 alpn = [a for a in c.extensions.extensions if a.type == 16] assert len(alpn) == 1 - assert alpn[0].alpn_protocols == [b"h2", b"http/1.1"] + assert alpn[0].body.alpn_protocols[0].name == b"h2" + assert alpn[0].body.alpn_protocols[1].name == b"http/1.1" sni = [a for a in c.extensions.extensions if a.type == 0] assert len(sni) == 1 - assert sni[0].server_names[0].name_type == 0 - assert sni[0].server_names[0].host_name == b"example.com" + assert sni[0].body.server_names[0].name_type == 0 + assert sni[0].body.server_names[0].host_name == b"example.com" def test_parse_no_extensions(): @@ -32,7 +36,8 @@ def test_parse_no_extensions(): "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" "61006200640100" ) - c = tls_parser.ClientHello.parse(data) + + c = tls_client_hello.TlsClientHello(KaitaiStream(io.BytesIO(data))) assert c.version.major == 3 assert c.version.minor == 1 - assert c.extensions is None + assert c.extensions == [] diff --git a/test/mitmproxy/proxy/protocol/test_tls.py b/test/mitmproxy/proxy/protocol/test_tls.py index e17ee46fe..980ba7bd6 100644 --- a/test/mitmproxy/proxy/protocol/test_tls.py +++ b/test/mitmproxy/proxy/protocol/test_tls.py @@ -23,4 +23,5 @@ class TestClientHello: ) c = TlsClientHello(data) assert c.sni == 'example.com' - assert c.alpn_protocols == [b'h2', b'http/1.1'] + assert c.alpn_protocols[0].name == b'h2' + assert c.alpn_protocols[1].name == b'http/1.1'