From 33689c6b2d2373279ec9ca7bbdef75319d1d1d9c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 4 Dec 2016 12:06:22 +0100 Subject: [PATCH] upgrade to construct 2.8 and new API --- mitmproxy/contrib/tls/_constructs.py | 235 +++++++++++++-------------- mitmproxy/contrib/tls/utils.py | 22 --- mitmproxy/proxy/protocol/tls.py | 10 +- setup.py | 2 +- 4 files changed, 119 insertions(+), 150 deletions(-) delete mode 100644 mitmproxy/contrib/tls/utils.py diff --git a/mitmproxy/contrib/tls/_constructs.py b/mitmproxy/contrib/tls/_constructs.py index f599f0fd5..8b3f12af9 100644 --- a/mitmproxy/contrib/tls/_constructs.py +++ b/mitmproxy/contrib/tls/_constructs.py @@ -3,120 +3,122 @@ # for complete details. -from construct import (Array, Bytes, Struct, UBInt16, UBInt32, UBInt8, PascalString, Embed, TunnelAdapter, GreedyRange, - Switch, OptionalGreedyRange, Optional) - -from .utils import UBInt24 - -ProtocolVersion = Struct( - "version", - UBInt8("major"), - UBInt8("minor"), +from construct import ( + Array, + Bytes, + Struct, + VarInt, + Int8ub, + Int16ub, + Int24ub, + Int32ub, + PascalString, + Embedded, + Prefixed, + Range, + GreedyRange, + Switch, + Optional, ) -TLSPlaintext = Struct( - "TLSPlaintext", - UBInt8("type"), +ProtocolVersion = "version" / Struct( + "major" / Int8ub, + "minor" / Int8ub, +) + +TLSPlaintext = "TLSPlaintext" / Struct( + "type" / Int8ub, ProtocolVersion, - UBInt16("length"), # TODO: Reject packets with length > 2 ** 14 - Bytes("fragment", lambda ctx: ctx.length), + "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 + "fragment" / Bytes(lambda ctx: ctx.length), ) -TLSCompressed = Struct( - "TLSCompressed", - UBInt8("type"), +TLSCompressed = "TLSCompressed" / Struct( + "type" / Int8ub, ProtocolVersion, - UBInt16("length"), # TODO: Reject packets with length > 2 ** 14 + 1024 - Bytes("fragment", lambda ctx: ctx.length), + "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 + 1024 + "fragment" / Bytes(lambda ctx: ctx.length), ) -TLSCiphertext = Struct( - "TLSCiphertext", - UBInt8("type"), +TLSCiphertext = "TLSCiphertext" / Struct( + "type" / Int8ub, ProtocolVersion, - UBInt16("length"), # TODO: Reject packets with length > 2 ** 14 + 2048 - Bytes("fragment", lambda ctx: ctx.length), + "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 + 2048 + "fragment" / Bytes(lambda ctx: ctx.length), ) -Random = Struct( - "random", - UBInt32("gmt_unix_time"), - Bytes("random_bytes", 28), +Random = "random" / Struct( + "gmt_unix_time" / Int32ub, + "random_bytes" / Bytes(28), ) -SessionID = Struct( - "session_id", - UBInt8("length"), - Bytes("session_id", lambda ctx: ctx.length), +SessionID = "session_id" / Struct( + "length" / Int8ub, + "session_id" / Bytes(lambda ctx: ctx.length), ) -CipherSuites = Struct( - "cipher_suites", - UBInt16("length"), # TODO: Reject packets of length 0 - Array(lambda ctx: ctx.length // 2, UBInt16("cipher_suites")), +CipherSuites = "cipher_suites" / Struct( + "length" / Int16ub, # TODO: Reject packets of length 0 + Array(lambda ctx: ctx.length // 2, "cipher_suites" / Int16ub), ) -CompressionMethods = Struct( - "compression_methods", - UBInt8("length"), # TODO: Reject packets of length 0 - Array(lambda ctx: ctx.length, UBInt8("compression_methods")), +CompressionMethods = "compression_methods" / Struct( + "length" / Int8ub, # TODO: Reject packets of length 0 + Array(lambda ctx: ctx.length, "compression_methods" / Int8ub), ) ServerName = Struct( - "", - UBInt8("type"), - PascalString("name", length_field=UBInt16("length")), + "type" / Int8ub, + "name" / PascalString("length" / Int16ub), ) -SNIExtension = Struct( - "", - TunnelAdapter( - PascalString("server_names", length_field=UBInt16("length")), - TunnelAdapter( - PascalString("", length_field=UBInt16("length")), - GreedyRange(ServerName) - ), - ), +SNIExtension = Prefixed( + Int16ub, + Struct( + Int16ub, + "server_names" / GreedyRange( + "server_name" / Struct( + "name_type" / Int8ub, + "host_name" / PascalString("length" / Int16ub), + ) + ) + ) ) -ALPNExtension = Struct( - "", - TunnelAdapter( - PascalString("alpn_protocols", length_field=UBInt16("length")), - TunnelAdapter( - PascalString("", length_field=UBInt16("length")), - GreedyRange(PascalString("name")) +ALPNExtension = Prefixed( + Int16ub, + Struct( + Int16ub, + "alpn_protocols" / GreedyRange( + "name" / PascalString(Int8ub), ), - ), + ) ) UnknownExtension = Struct( - "", - PascalString("bytes", length_field=UBInt16("extensions_length")) + "bytes" / PascalString("length" / Int16ub) ) -Extension = Struct( - "Extension", - UBInt16("type"), - Embed( +Extension = "Extension" / Struct( + "type" / Int16ub, + Embedded( Switch( - "", lambda ctx: ctx.type, + lambda ctx: ctx.type, { 0x00: SNIExtension, - 0x10: ALPNExtension + 0x10: ALPNExtension, }, default=UnknownExtension ) ) ) -extensions = TunnelAdapter( - Optional(PascalString("extensions", length_field=UBInt16("extensions_length"))), - OptionalGreedyRange(Extension) +extensions = "extensions" / Struct( + Int16ub, + "extensions" / GreedyRange(Extension) ) -ClientHello = Struct( - "ClientHello", +ClientHello = "ClientHello" / Struct( ProtocolVersion, Random, SessionID, @@ -125,31 +127,27 @@ ClientHello = Struct( extensions, ) -ServerHello = Struct( - "ServerHello", +ServerHello = "ServerHello" / Struct( ProtocolVersion, Random, SessionID, - Bytes("cipher_suite", 2), - UBInt8("compression_method"), + "cipher_suite" / Bytes(2), + "compression_method" / Int8ub, extensions, ) -ClientCertificateType = Struct( - "certificate_types", - UBInt8("length"), # TODO: Reject packets of length 0 - Array(lambda ctx: ctx.length, UBInt8("certificate_types")), +ClientCertificateType = "certificate_types" / Struct( + "length" / Int8ub, # TODO: Reject packets of length 0 + Array(lambda ctx: ctx.length, "certificate_types" / Int8ub), ) -SignatureAndHashAlgorithm = Struct( - "algorithms", - UBInt8("hash"), - UBInt8("signature"), +SignatureAndHashAlgorithm = "algorithms" / Struct( + "hash" / Int8ub, + "signature" / Int8ub, ) -SupportedSignatureAlgorithms = Struct( - "supported_signature_algorithms", - UBInt16("supported_signature_algorithms_length"), +SupportedSignatureAlgorithms = "supported_signature_algorithms" / Struct( + "supported_signature_algorithms_length" / Int16ub, # TODO: Reject packets of length 0 Array( lambda ctx: ctx.supported_signature_algorithms_length / 2, @@ -157,56 +155,49 @@ SupportedSignatureAlgorithms = Struct( ), ) -DistinguishedName = Struct( - "certificate_authorities", - UBInt16("length"), - Bytes("certificate_authorities", lambda ctx: ctx.length), +DistinguishedName = "certificate_authorities" / Struct( + "length" / Int16ub, + "certificate_authorities" / Bytes(lambda ctx: ctx.length), ) -CertificateRequest = Struct( - "CertificateRequest", +CertificateRequest = "CertificateRequest" / Struct( ClientCertificateType, SupportedSignatureAlgorithms, DistinguishedName, ) -ServerDHParams = Struct( - "ServerDHParams", - UBInt16("dh_p_length"), - Bytes("dh_p", lambda ctx: ctx.dh_p_length), - UBInt16("dh_g_length"), - Bytes("dh_g", lambda ctx: ctx.dh_g_length), - UBInt16("dh_Ys_length"), - Bytes("dh_Ys", lambda ctx: ctx.dh_Ys_length), +ServerDHParams = "ServerDHParams" / Struct( + "dh_p_length" / Int16ub, + "dh_p" / Bytes(lambda ctx: ctx.dh_p_length), + "dh_g_length" / Int16ub, + "dh_g" / Bytes(lambda ctx: ctx.dh_g_length), + "dh_Ys_length" / Int16ub, + "dh_Ys" / Bytes(lambda ctx: ctx.dh_Ys_length), ) -PreMasterSecret = Struct( - "pre_master_secret", +PreMasterSecret = "pre_master_secret" / Struct( ProtocolVersion, - Bytes("random_bytes", 46), + "random_bytes" / Bytes(46), ) -ASN1Cert = Struct( - "ASN1Cert", - UBInt32("length"), # TODO: Reject packets with length not in 1..2^24-1 - Bytes("asn1_cert", lambda ctx: ctx.length), +ASN1Cert = "ASN1Cert" / Struct( + "length" / Int32ub, # TODO: Reject packets with length not in 1..2^24-1 + "asn1_cert" / Bytes(lambda ctx: ctx.length), ) -Certificate = Struct( - "Certificate", # TODO: Reject packets with length > 2 ** 24 - 1 - UBInt32("certificates_length"), - Bytes("certificates_bytes", lambda ctx: ctx.certificates_length), +Certificate = "Certificate" / Struct( + # TODO: Reject packets with length > 2 ** 24 - 1 + "certificates_length" / Int32ub, + "certificates_bytes" / Bytes(lambda ctx: ctx.certificates_length), ) -Handshake = Struct( - "Handshake", - UBInt8("msg_type"), - UBInt24("length"), - Bytes("body", lambda ctx: ctx.length), +Handshake = "Handshake" / Struct( + "msg_type" / Int8ub, + "length" / Int24ub, + "body" / Bytes(lambda ctx: ctx.length), ) -Alert = Struct( - "Alert", - UBInt8("level"), - UBInt8("description"), +Alert = "Alert" / Struct( + "level" / Int8ub, + "description" / Int8ub, ) diff --git a/mitmproxy/contrib/tls/utils.py b/mitmproxy/contrib/tls/utils.py deleted file mode 100644 index ff442387f..000000000 --- a/mitmproxy/contrib/tls/utils.py +++ /dev/null @@ -1,22 +0,0 @@ -# This file is dual licensed under the terms of the Apache License, Version -# 2.0, and the BSD License. See the LICENSE file in the root of this repository -# for complete details. - - -import construct - -class _UBInt24(construct.Adapter): - def _encode(self, obj, context): - return bytes( - (obj & 0xFF0000) >> 16, - (obj & 0x00FF00) >> 8, - obj & 0x0000FF - ) - - def _decode(self, obj, context): - obj = bytearray(obj) - return (obj[0] << 16 | obj[1] << 8 | obj[2]) - - -def UBInt24(name): # noqa - return _UBInt24(construct.Bytes(name, 3)) diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py index c6dc2c33c..58d9e28d3 100644 --- a/mitmproxy/proxy/protocol/tls.py +++ b/mitmproxy/proxy/protocol/tls.py @@ -259,19 +259,19 @@ class TlsClientHello: @property def sni(self): - for extension in self._client_hello.extensions: + for extension in self._client_hello.extensions.extensions: is_valid_sni_extension = ( extension.type == 0x00 and len(extension.server_names) == 1 and - extension.server_names[0].type == 0 and - check.is_valid_host(extension.server_names[0].name) + extension.server_names[0].name_type == 0 and + check.is_valid_host(extension.server_names[0].host_name) ) if is_valid_sni_extension: - return extension.server_names[0].name.decode("idna") + return extension.server_names[0].host_name.decode("idna") @property def alpn_protocols(self): - for extension in self._client_hello.extensions: + for extension in self._client_hello.extensions.extensions: if extension.type == 0x10: return list(extension.alpn_protocols) diff --git a/setup.py b/setup.py index 6c328af07..56ba46fc8 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ setup( "click>=6.2, <7.0", "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! "configargparse>=0.10, <0.12", - "construct>=2.5.2, <2.6", + "construct>=2.8, <2.9", "cryptography>=1.3, <1.7", "cssutils>=1.0.1, <1.1", "Flask>=0.10.1, <0.12",