[sans-io] improve ClientHello parsing, add tests

This commit is contained in:
Maximilian Hils 2017-08-15 16:01:59 +02:00
parent 0cb06b428e
commit 9ea0259bb7
2 changed files with 109 additions and 23 deletions

View File

@ -1,7 +1,7 @@
import os import os
from enum import Enum
import struct import struct
from typing import MutableMapping, Generator, Optional from enum import Enum
from typing import MutableMapping, Generator, Optional, Iterable, Iterator
from OpenSSL import SSL from OpenSSL import SSL
@ -23,31 +23,61 @@ class ConnectionState(Enum):
ESTABLISHED = 6 ESTABLISHED = 6
def get_client_hello(client_conn): def is_tls_handshake_record(d: bytes) -> bool:
""" """
Read all records from client buffer that contain the initial client hello message.
client_conn:
bytearray
Returns: Returns:
The raw handshake packet bytes, without TLS record header(s). True, if the passed bytes start with the TLS record magic bytes
False, otherwise.
"""
# TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2.
# TLS 1.3 mandates legacy_record_version to be 0x0301.
# http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello
return (
len(d) >= 3 and
d[0] == 0x16 and
d[1] == 0x03 and
0x0 <= d[2] <= 0x03
)
def handshake_record_contents(data: bytes) -> Iterator[bytes]:
"""
Returns a generator that yields the bytes contained in each handshake record.
This will raise an error on the first non-handshake record, so fully exhausting this
generator is a bad idea.
"""
offset = 0
while True:
if len(data) < offset + 5:
return
record_header = data[offset:offset + 5]
if not is_tls_handshake_record(record_header):
raise ValueError(f"Expected TLS record, got {record_header} instead.")
record_size = struct.unpack("!H", record_header[3:])[0]
if record_size == 0:
raise ValueError("Record must not be empty.")
offset += 5
if len(data) < offset + record_size:
return
record_body = data[offset:offset + record_size]
yield record_body
offset += record_size
def get_client_hello(data: bytes) -> Optional[bytes]:
"""
Read all TLS records that contain the initial ClientHello.
Returns the raw handshake packet bytes, without TLS record headers.
""" """
client_hello = b"" client_hello = b""
client_hello_size = 1 for d in handshake_record_contents(data):
offset = 0 client_hello += d
while len(client_hello) < client_hello_size: if len(client_hello) >= 4:
record_header = client_conn[offset:5] client_hello_size = struct.unpack("!I", b'\x00' + client_hello[1:4])[0] + 4
if not tls.is_tls_record_magic(record_header) or len(record_header) != 5: if len(client_hello) >= client_hello_size:
raise exceptions.TlsProtocolException('Expected TLS record, got "%s" instead.' % record_header) return client_hello[:client_hello_size]
record_size = struct.unpack("!H", record_header[3:])[0] + 5 return None
record_body = client_conn[offset + 5: record_size]
if len(record_body) != record_size - 5:
raise exceptions.TlsProtocolException("Unexpected EOF in TLS handshake: %s" % record_body)
client_hello += record_body
offset += record_size
client_hello_size = struct.unpack("!I", b'\x00' + client_hello[1:4])[0] + 4
return client_hello
class TLSLayer(layer.Layer): class TLSLayer(layer.Layer):

View File

@ -0,0 +1,56 @@
import pytest
from mitmproxy.proxy2.layers import tls
def test_is_tls_handshake_record():
assert tls.is_tls_handshake_record(bytes.fromhex("160300"))
assert tls.is_tls_handshake_record(bytes.fromhex("160301"))
assert tls.is_tls_handshake_record(bytes.fromhex("160302"))
assert tls.is_tls_handshake_record(bytes.fromhex("160303"))
assert not tls.is_tls_handshake_record(bytes.fromhex("ffffff"))
assert not tls.is_tls_handshake_record(bytes.fromhex(""))
assert not tls.is_tls_handshake_record(bytes.fromhex("160304"))
assert not tls.is_tls_handshake_record(bytes.fromhex("150301"))
def test_record_contents():
data = bytes.fromhex(
"1603010002beef"
"1603010001ff"
)
assert list(tls.handshake_record_contents(data)) == [
b"\xbe\xef", b"\xff"
]
for i in range(6):
assert list(tls.handshake_record_contents(data[:i])) == []
def test_record_contents_err():
with pytest.raises(ValueError, msg="Expected TLS record"):
next(tls.handshake_record_contents(b"GET /error"))
empty_record = bytes.fromhex("1603010000")
with pytest.raises(ValueError, msg="Record must not be empty"):
next(tls.handshake_record_contents(empty_record))
client_hello_no_extensions = bytes.fromhex(
"0100006103015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637"
"78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000"
"61006200640100"
)
def test_get_client_hello():
single_record = bytes.fromhex("1603010065") + client_hello_no_extensions
assert tls.get_client_hello(single_record) == client_hello_no_extensions
split_over_two_records = (
bytes.fromhex("1603010020") + client_hello_no_extensions[:32] +
bytes.fromhex("1603010045") + client_hello_no_extensions[32:]
)
assert tls.get_client_hello(split_over_two_records) == client_hello_no_extensions
incomplete = split_over_two_records[:42]
assert tls.get_client_hello(incomplete) is None