Merge pull request #67 from Kriechi/http2-wip

HTTP/2: preparations for pathod
This commit is contained in:
Aldo Cortesi 2015-06-15 11:02:44 +12:00
commit 4fbe406e2e
4 changed files with 223 additions and 51 deletions

View File

@ -26,55 +26,80 @@ class HTTP2Protocol(object):
)
# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
CLIENT_CONNECTION_PREFACE =\
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')
ALPN_PROTO_H2 = 'h2'
def __init__(self, tcp_client):
self.tcp_client = tcp_client
def __init__(self, tcp_handler, is_server=False):
self.tcp_handler = tcp_handler
self.is_server = is_server
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.encoder = Encoder()
self.decoder = Decoder()
self.connection_preface_performed = False
def check_alpn(self):
alp = self.tcp_client.get_alpn_proto_negotiated()
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != self.ALPN_PROTO_H2:
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
def perform_connection_preface(self):
self.tcp_client.wfile.write(
bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex')))
self.send_frame(frame.SettingsFrame(state=self))
# read server settings frame
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
def _receive_settings(self):
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
assert isinstance(frm, frame.SettingsFrame)
self._apply_settings(frm.settings)
# read setting ACK frame
def _read_settings_ack(self):
settings_ack_frame = self.read_frame()
assert isinstance(settings_ack_frame, frame.SettingsFrame)
assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
assert len(settings_ack_frame.settings) == 0
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
self.send_frame(frame.SettingsFrame(state=self))
self._receive_settings()
self._read_settings_ack()
def perform_client_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(frame.SettingsFrame(state=self))
self._receive_settings()
self._read_settings_ack()
def next_stream_id(self):
if self.current_stream_id is None:
self.current_stream_id = 1
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
def send_frame(self, frame):
raw_bytes = frame.to_bytes()
self.tcp_client.wfile.write(raw_bytes)
self.tcp_client.wfile.flush()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
def read_frame(self):
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
if isinstance(frm, frame.SettingsFrame):
self._apply_settings(frm.settings)
@ -127,10 +152,13 @@ class HTTP2Protocol(object):
if headers is None:
headers = []
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
headers = [
(b':method', bytes(method)),
(b':path', bytes(path)),
(b':scheme', b'https')] + headers
(b':scheme', b'https'),
(b':authority', authority),
] + headers
stream_id = self.next_stream_id()
@ -139,25 +167,50 @@ class HTTP2Protocol(object):
self._create_body(body, stream_id)))
def read_response(self):
headers, body = self._receive_transmission()
return headers[':status'], headers, body
def read_request(self):
return self._receive_transmission()
def _receive_transmission(self):
body_expected = True
header_block_fragment = b''
body = b''
while True:
frm = self.read_frame()
if isinstance(frm, frame.HeadersFrame):
if isinstance(frm, frame.HeadersFrame)\
or isinstance(frm, frame.ContinuationFrame):
header_block_fragment += frm.header_block_fragment
if frm.flags | frame.Frame.FLAG_END_HEADERS:
if frm.flags & frame.Frame.FLAG_END_HEADERS:
if frm.flags & frame.Frame.FLAG_END_STREAM:
body_expected = False
break
while True:
while body_expected:
frm = self.read_frame()
if isinstance(frm, frame.DataFrame):
body += frm.payload
if frm.flags | frame.Frame.FLAG_END_STREAM:
if frm.flags & frame.Frame.FLAG_END_STREAM:
break
# TODO: implement window update & flow
headers = {}
for header, value in self.decoder.decode(header_block_fragment):
headers[header] = value
return headers[':status'], headers, body
return headers, body
def create_response(self, code, headers=None, body=None):
if headers is None:
headers = []
headers = [(b':status', bytes(str(code)))] + headers
stream_id = self.next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(body is None)),
self._create_body(body, stream_id)))

View File

@ -19,6 +19,9 @@ SSLv2_METHOD = SSL.SSLv2_METHOD
SSLv3_METHOD = SSL.SSLv3_METHOD
SSLv23_METHOD = SSL.SSLv23_METHOD
TLSv1_METHOD = SSL.TLSv1_METHOD
TLSv1_1_METHOD = SSL.TLSv1_1_METHOD
TLSv1_2_METHOD = SSL.TLSv1_2_METHOD
OP_NO_SSLv2 = SSL.OP_NO_SSLv2
OP_NO_SSLv3 = SSL.OP_NO_SSLv3
@ -376,7 +379,7 @@ class _Connection(object):
alpn_select=None,
):
"""
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
:rtype : SSL.Context
@ -404,16 +407,17 @@ class _Connection(object):
context.set_info_callback(log_ssl_key)
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
# advertise application layer protocols
if alpn_protos is not None:
# advertise application layer protocols
context.set_alpn_protos(alpn_protos)
# select application layer protocol
if alpn_select is not None:
def alpn_select_f(conn, options):
return bytes(alpn_select)
context.set_alpn_select_callback(alpn_select_f)
elif alpn_select is not None:
# select application layer protocol
def alpn_select_callback(conn, options):
if alpn_select in options:
return bytes(alpn_select)
else: # pragma no cover
return options[0]
context.set_alpn_select_callback(alpn_select_callback)
return context
@ -499,9 +503,9 @@ class TCPClient(_Connection):
return self.connection.gettimeout()
def get_alpn_proto_negotiated(self):
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else: # pragma no cover
else:
return None
@ -531,7 +535,6 @@ class BaseHandler(_Connection):
request_client_cert=None,
chain_file=None,
dhparams=None,
alpn_select=None,
**sslctx_kwargs):
"""
cert: A certutils.SSLCert object.
@ -558,9 +561,7 @@ class BaseHandler(_Connection):
until then we're conservative.
"""
context = self._create_ssl_context(
alpn_select=alpn_select,
**sslctx_kwargs)
context = self._create_ssl_context(**sslctx_kwargs)
context.use_privatekey(key)
context.use_certificate(cert.x509)
@ -585,7 +586,7 @@ class BaseHandler(_Connection):
return context
def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
"""
Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...)
@ -594,7 +595,6 @@ class BaseHandler(_Connection):
context = self.create_ssl_context(
cert,
key,
alpn_select=alpn_select,
**sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state()
@ -612,6 +612,12 @@ class BaseHandler(_Connection):
def settimeout(self, n):
self.connection.settimeout(n)
def get_alpn_proto_negotiated(self):
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else:
return None
class TCPServer(object):
request_queue_size = 20

View File

@ -1,4 +1,3 @@
import OpenSSL
from netlib import http2
@ -50,7 +49,39 @@ class TestCheckALPNMismatch(test.ServerTestBase):
tutils.raises(NotImplementedError, protocol.check_alpn)
class TestPerformConnectionPreface(test.ServerTestBase):
class TestPerformServerConnectionPreface(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
# send magic
self.wfile.write(
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex'))
self.wfile.flush()
# send empty settings frame
self.wfile.write('000000040000000000'.decode('hex'))
self.wfile.flush()
# check empty settings frame
assert self.rfile.read(9) ==\
'000000040000000000'.decode('hex')
# check settings acknowledgement
assert self.rfile.read(9) == \
'000000040100000000'.decode('hex')
# send settings acknowledgement
self.wfile.write('000000040100000000'.decode('hex'))
self.wfile.flush()
def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
protocol = http2.HTTP2Protocol(c)
protocol.perform_server_connection_preface()
class TestPerformClientConnectionPreface(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
@ -74,21 +105,18 @@ class TestPerformConnectionPreface(test.ServerTestBase):
self.wfile.write('000000040100000000'.decode('hex'))
self.wfile.flush()
ssl = True
def test_perform_connection_preface(self):
def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
protocol.perform_connection_preface()
protocol.perform_client_connection_preface()
class TestStreamIds():
class TestClientStreamIds():
c = tcp.TCPClient(("127.0.0.1", 0))
protocol = http2.HTTP2Protocol(c)
def test_stream_ids(self):
def test_client_stream_ids(self):
assert self.protocol.current_stream_id is None
assert self.protocol.next_stream_id() == 1
assert self.protocol.current_stream_id == 1
@ -98,6 +126,20 @@ class TestStreamIds():
assert self.protocol.current_stream_id == 5
class TestServerStreamIds():
c = tcp.TCPClient(("127.0.0.1", 0))
protocol = http2.HTTP2Protocol(c, is_server=True)
def test_server_stream_ids(self):
assert self.protocol.current_stream_id is None
assert self.protocol.next_stream_id() == 2
assert self.protocol.current_stream_id == 2
assert self.protocol.next_stream_id() == 4
assert self.protocol.current_stream_id == 4
assert self.protocol.next_stream_id() == 6
assert self.protocol.current_stream_id == 6
class TestApplySettings(test.ServerTestBase):
class handler(tcp.BaseHandler):
@ -180,14 +222,14 @@ class TestCreateRequest():
def test_create_request_simple(self):
bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/')
assert len(bytes) == 1
assert bytes[0] == '000003010500000001828487'.decode('hex')
assert bytes[0] == '00000c0105000000018284874187089d5c0b8170ff'.decode('hex')
def test_create_request_with_body(self):
bytes = http2.HTTP2Protocol(self.c).create_request(
'GET', '/', [(b'foo', b'bar')], 'foobar')
assert len(bytes) == 2
assert bytes[0] ==\
'00000b010400000001828487408294e7838c767f'.decode('hex')
'0000140104000000018284874187089d5c0b8170ff408294e7838c767f'.decode('hex')
assert bytes[1] ==\
'000006000100000001666f6f626172'.decode('hex')
@ -213,5 +255,71 @@ class TestReadResponse(test.ServerTestBase):
status, headers, body = protocol.read_response()
assert headers == {':status': '200', 'etag': 'foobar'}
assert status == '200'
assert status == "200"
assert body == b'foobar'
class TestReadEmptyResponse(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
b'00000801050000000188628594e78c767f'.decode('hex'))
self.wfile.flush()
ssl = True
def test_read_empty_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
status, headers, body = protocol.read_response()
assert headers == {':status': '200', 'etag': 'foobar'}
assert status == "200"
assert body == b''
class TestReadRequest(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
b'000003010400000001828487'.decode('hex'))
self.wfile.write(
b'000006000100000001666f6f626172'.decode('hex'))
self.wfile.flush()
ssl = True
def test_read_request(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c, is_server=True)
headers, body = protocol.read_request()
assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'}
assert body == b'foobar'
class TestCreateResponse():
c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_request_simple(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200)
assert len(bytes) == 1
assert bytes[0] ==\
'00000101050000000288'.decode('hex')
def test_create_request_with_body(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(
200, [(b'foo', b'bar')], 'foobar')
assert len(bytes) == 2
assert bytes[0] ==\
'00000901040000000288408294e7838c767f'.decode('hex')
assert bytes[1] ==\
'000006000100000002666f6f626172'.decode('hex')

View File

@ -376,6 +376,11 @@ class TestALPN(test.ServerTestBase):
c.convert_to_ssl(alpn_protos=["foobar"])
assert c.get_alpn_proto_negotiated() == "foobar"
def test_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
assert c.get_alpn_proto_negotiated() == None
else:
def test_none_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))