[sans-io] improve ClientHello parsing, add tests

This commit is contained in:
Maximilian Hils 2017-08-15 16:01:59 +02:00
parent 0cb06b428e
commit 9ea0259bb7
2 changed files with 109 additions and 23 deletions

View File

@ -1,7 +1,7 @@
import os
from enum import Enum
import struct
from typing import MutableMapping, Generator, Optional
from enum import Enum
from typing import MutableMapping, Generator, Optional, Iterable, Iterator
from OpenSSL import SSL
@ -23,31 +23,61 @@ class ConnectionState(Enum):
ESTABLISHED = 6
def get_client_hello(client_conn):
def is_tls_handshake_record(d: bytes) -> bool:
"""
Read all records from client buffer that contain the initial client hello message.
client_conn:
bytearray
Returns:
The raw handshake packet bytes, without TLS record header(s).
True, if the passed bytes start with the TLS record magic bytes
False, otherwise.
"""
# TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2.
# TLS 1.3 mandates legacy_record_version to be 0x0301.
# http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello
return (
len(d) >= 3 and
d[0] == 0x16 and
d[1] == 0x03 and
0x0 <= d[2] <= 0x03
)
def handshake_record_contents(data: bytes) -> Iterator[bytes]:
"""
Returns a generator that yields the bytes contained in each handshake record.
This will raise an error on the first non-handshake record, so fully exhausting this
generator is a bad idea.
"""
offset = 0
while True:
if len(data) < offset + 5:
return
record_header = data[offset:offset + 5]
if not is_tls_handshake_record(record_header):
raise ValueError(f"Expected TLS record, got {record_header} instead.")
record_size = struct.unpack("!H", record_header[3:])[0]
if record_size == 0:
raise ValueError("Record must not be empty.")
offset += 5
if len(data) < offset + record_size:
return
record_body = data[offset:offset + record_size]
yield record_body
offset += record_size
def get_client_hello(data: bytes) -> Optional[bytes]:
"""
Read all TLS records that contain the initial ClientHello.
Returns the raw handshake packet bytes, without TLS record headers.
"""
client_hello = b""
client_hello_size = 1
offset = 0
while len(client_hello) < client_hello_size:
record_header = client_conn[offset:5]
if not tls.is_tls_record_magic(record_header) or len(record_header) != 5:
raise exceptions.TlsProtocolException('Expected TLS record, got "%s" instead.' % record_header)
record_size = struct.unpack("!H", record_header[3:])[0] + 5
record_body = client_conn[offset + 5: record_size]
if len(record_body) != record_size - 5:
raise exceptions.TlsProtocolException("Unexpected EOF in TLS handshake: %s" % record_body)
client_hello += record_body
offset += record_size
for d in handshake_record_contents(data):
client_hello += d
if len(client_hello) >= 4:
client_hello_size = struct.unpack("!I", b'\x00' + client_hello[1:4])[0] + 4
return client_hello
if len(client_hello) >= client_hello_size:
return client_hello[:client_hello_size]
return None
class TLSLayer(layer.Layer):

View File

@ -0,0 +1,56 @@
import pytest
from mitmproxy.proxy2.layers import tls
def test_is_tls_handshake_record():
assert tls.is_tls_handshake_record(bytes.fromhex("160300"))
assert tls.is_tls_handshake_record(bytes.fromhex("160301"))
assert tls.is_tls_handshake_record(bytes.fromhex("160302"))
assert tls.is_tls_handshake_record(bytes.fromhex("160303"))
assert not tls.is_tls_handshake_record(bytes.fromhex("ffffff"))
assert not tls.is_tls_handshake_record(bytes.fromhex(""))
assert not tls.is_tls_handshake_record(bytes.fromhex("160304"))
assert not tls.is_tls_handshake_record(bytes.fromhex("150301"))
def test_record_contents():
data = bytes.fromhex(
"1603010002beef"
"1603010001ff"
)
assert list(tls.handshake_record_contents(data)) == [
b"\xbe\xef", b"\xff"
]
for i in range(6):
assert list(tls.handshake_record_contents(data[:i])) == []
def test_record_contents_err():
with pytest.raises(ValueError, msg="Expected TLS record"):
next(tls.handshake_record_contents(b"GET /error"))
empty_record = bytes.fromhex("1603010000")
with pytest.raises(ValueError, msg="Record must not be empty"):
next(tls.handshake_record_contents(empty_record))
client_hello_no_extensions = bytes.fromhex(
"0100006103015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637"
"78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000"
"61006200640100"
)
def test_get_client_hello():
single_record = bytes.fromhex("1603010065") + client_hello_no_extensions
assert tls.get_client_hello(single_record) == client_hello_no_extensions
split_over_two_records = (
bytes.fromhex("1603010020") + client_hello_no_extensions[:32] +
bytes.fromhex("1603010045") + client_hello_no_extensions[32:]
)
assert tls.get_client_hello(split_over_two_records) == client_hello_no_extensions
incomplete = split_over_two_records[:42]
assert tls.get_client_hello(incomplete) is None