diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index feac220c9..1e722dfbf 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -26,55 +26,80 @@ class HTTP2Protocol(object): ) # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_client): - self.tcp_client = tcp_client + def __init__(self, tcp_handler, is_server=False): + self.tcp_handler = tcp_handler + self.is_server = is_server self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None self.encoder = Encoder() self.decoder = Decoder() + self.connection_preface_performed = False def check_alpn(self): - alp = self.tcp_client.get_alpn_proto_negotiated() + alp = self.tcp_handler.get_alpn_proto_negotiated() if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True - def perform_connection_preface(self): - self.tcp_client.wfile.write( - bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(frame.SettingsFrame(state=self)) - - # read server settings frame - frm = frame.Frame.from_file(self.tcp_client.rfile, self) + def _receive_settings(self): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) assert isinstance(frm, frame.SettingsFrame) self._apply_settings(frm.settings) - # read setting ACK frame + def _read_settings_ack(self): settings_ack_frame = self.read_frame() assert isinstance(settings_ack_frame, frame.SettingsFrame) assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() + def next_stream_id(self): if self.current_stream_id is None: - self.current_stream_id = 1 + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 else: self.current_stream_id += 2 return self.current_stream_id def send_frame(self, frame): raw_bytes = frame.to_bytes() - self.tcp_client.wfile.write(raw_bytes) - self.tcp_client.wfile.flush() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() def read_frame(self): - frm = frame.Frame.from_file(self.tcp_client.rfile, self) + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) if isinstance(frm, frame.SettingsFrame): self._apply_settings(frm.settings) @@ -127,10 +152,13 @@ class HTTP2Protocol(object): if headers is None: headers = [] + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host headers = [ (b':method', bytes(method)), (b':path', bytes(path)), - (b':scheme', b'https')] + headers + (b':scheme', b'https'), + (b':authority', authority), + ] + headers stream_id = self.next_stream_id() @@ -139,25 +167,50 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): + headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + header_block_fragment = b'' body = b'' while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame): + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): header_block_fragment += frm.header_block_fragment - if frm.flags | frame.Frame.FLAG_END_HEADERS: + if frm.flags & frame.Frame.FLAG_END_HEADERS: + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False break - while True: + while body_expected: frm = self.read_frame() if isinstance(frm, frame.DataFrame): body += frm.payload - if frm.flags | frame.Frame.FLAG_END_STREAM: + if frm.flags & frame.Frame.FLAG_END_STREAM: break + # TODO: implement window update & flow headers = {} for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - return headers[':status'], headers, body + return headers, body + + def create_response(self, code, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) diff --git a/netlib/tcp.py b/netlib/tcp.py index 9a9800359..897e3e657 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -19,6 +19,9 @@ SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD +TLSv1_1_METHOD = SSL.TLSv1_1_METHOD +TLSv1_2_METHOD = SSL.TLSv1_2_METHOD + OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 @@ -376,7 +379,7 @@ class _Connection(object): alpn_select=None, ): """ - :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html :rtype : SSL.Context @@ -404,16 +407,17 @@ class _Connection(object): context.set_info_callback(log_ssl_key) if OpenSSL._util.lib.Cryptography_HAS_ALPN: - # advertise application layer protocols if alpn_protos is not None: + # advertise application layer protocols context.set_alpn_protos(alpn_protos) - - # select application layer protocol - if alpn_select is not None: - def alpn_select_f(conn, options): - return bytes(alpn_select) - - context.set_alpn_select_callback(alpn_select_f) + elif alpn_select is not None: + # select application layer protocol + def alpn_select_callback(conn, options): + if alpn_select in options: + return bytes(alpn_select) + else: # pragma no cover + return options[0] + context.set_alpn_select_callback(alpn_select_callback) return context @@ -499,9 +503,9 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN: + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover + else: return None @@ -531,7 +535,6 @@ class BaseHandler(_Connection): request_client_cert=None, chain_file=None, dhparams=None, - alpn_select=None, **sslctx_kwargs): """ cert: A certutils.SSLCert object. @@ -558,9 +561,7 @@ class BaseHandler(_Connection): until then we're conservative. """ - context = self._create_ssl_context( - alpn_select=alpn_select, - **sslctx_kwargs) + context = self._create_ssl_context(**sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -585,7 +586,7 @@ class BaseHandler(_Connection): return context - def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs): + def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) @@ -594,7 +595,6 @@ class BaseHandler(_Connection): context = self.create_ssl_context( cert, key, - alpn_select=alpn_select, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() @@ -612,6 +612,12 @@ class BaseHandler(_Connection): def settimeout(self, n): self.connection.settimeout(n) + def get_alpn_proto_negotiated(self): + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + return self.connection.get_alpn_proto_negotiated() + else: + return None + class TCPServer(object): request_queue_size = 20 diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index cb46bc68a..34c69fa92 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -1,4 +1,3 @@ - import OpenSSL from netlib import http2 @@ -50,7 +49,39 @@ class TestCheckALPNMismatch(test.ServerTestBase): tutils.raises(NotImplementedError, protocol.check_alpn) -class TestPerformConnectionPreface(test.ServerTestBase): +class TestPerformServerConnectionPreface(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # send magic + self.wfile.write( + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_server_connection_preface() + + +class TestPerformClientConnectionPreface(test.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -74,21 +105,18 @@ class TestPerformConnectionPreface(test.ServerTestBase): self.wfile.write('000000040100000000'.decode('hex')) self.wfile.flush() - ssl = True - - def test_perform_connection_preface(self): + def test_perform_client_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - protocol.perform_connection_preface() + protocol.perform_client_connection_preface() -class TestStreamIds(): +class TestClientStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) protocol = http2.HTTP2Protocol(c) - def test_stream_ids(self): + def test_client_stream_ids(self): assert self.protocol.current_stream_id is None assert self.protocol.next_stream_id() == 1 assert self.protocol.current_stream_id == 1 @@ -98,6 +126,20 @@ class TestStreamIds(): assert self.protocol.current_stream_id == 5 +class TestServerStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol.next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol.next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + class TestApplySettings(test.ServerTestBase): class handler(tcp.BaseHandler): @@ -180,14 +222,14 @@ class TestCreateRequest(): def test_create_request_simple(self): bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') assert len(bytes) == 1 - assert bytes[0] == '000003010500000001828487'.decode('hex') + assert bytes[0] == '00000c0105000000018284874187089d5c0b8170ff'.decode('hex') def test_create_request_with_body(self): bytes = http2.HTTP2Protocol(self.c).create_request( 'GET', '/', [(b'foo', b'bar')], 'foobar') assert len(bytes) == 2 assert bytes[0] ==\ - '00000b010400000001828487408294e7838c767f'.decode('hex') + '0000140104000000018284874187089d5c0b8170ff408294e7838c767f'.decode('hex') assert bytes[1] ==\ '000006000100000001666f6f626172'.decode('hex') @@ -213,5 +255,71 @@ class TestReadResponse(test.ServerTestBase): status, headers, body = protocol.read_response() assert headers == {':status': '200', 'etag': 'foobar'} - assert status == '200' + assert status == "200" assert body == b'foobar' + + +class TestReadEmptyResponse(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801050000000188628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'' + + +class TestReadRequest(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c, is_server=True) + + headers, body = protocol.read_request() + + assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert body == b'foobar' + + +class TestCreateResponse(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_request_simple(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000000288'.decode('hex') + + def test_create_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( + 200, [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000901040000000288408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000002666f6f626172'.decode('hex') diff --git a/test/test_tcp.py b/test/test_tcp.py index d5506556c..8aa34d2b7 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -376,6 +376,11 @@ class TestALPN(test.ServerTestBase): c.convert_to_ssl(alpn_protos=["foobar"]) assert c.get_alpn_proto_negotiated() == "foobar" + def test_no_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + assert c.get_alpn_proto_negotiated() == None + else: def test_none_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port))