diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e46ad7abf..b098110a1 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -4,8 +4,10 @@ import collections import string import sys import urlparse +import time from netlib import odict, utils, tcp, http +from netlib.http import semantics from .. import status_codes from ..exceptions import * @@ -14,13 +16,10 @@ class TCPHandler(object): self.rfile = rfile self.wfile = wfile -class HTTP1Protocol(object): +class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): - if tcp_handler: - self.tcp_handler = tcp_handler - else: - self.tcp_handler = TCPHandler(rfile, wfile) + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -39,6 +38,10 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) @@ -106,6 +109,12 @@ class HTTP1Protocol(object): True ) + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + return http.Request( form_in, method, @@ -115,7 +124,9 @@ class HTTP1Protocol(object): path, httpversion, headers, - body + body, + timestamp_start, + timestamp_end, ) @@ -124,12 +135,15 @@ class HTTP1Protocol(object): Returns an http.Response By default, both response header and body are read. - If include_body=False is specified, content may be one of the + If include_body=False is specified, body may be one of the following: - None, if the response is technically allowed to have a response body - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() line = self.tcp_handler.rfile.readline() # Possible leftover from previous message @@ -149,7 +163,7 @@ class HTTP1Protocol(object): raise HttpError(502, "Invalid headers.") if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, request_method, @@ -157,10 +171,55 @@ class HTTP1Protocol(object): False ) else: - # if include_body==False then a None content means the body should be + # if include_body==False then a None body means the body should be # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + body = None + + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + return http.Response( + httpversion, + code, + msg, + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + if request.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_request_first_line(request) + headers = self._assemble_request_headers(request) + return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) + + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + if response.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_response_first_line(response) + headers = self._assemble_response_headers(response) + return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) def read_headers(self): @@ -331,7 +390,6 @@ class HTTP1Protocol(object): return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -494,3 +552,74 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) + + + @classmethod + def _assemble_request_first_line(self, request): + if request.form_in == "relative": + request_line = '%s %s HTTP/%s.%s' % ( + request.method, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "authority": + request_line = '%s %s:%s HTTP/%s.%s' % ( + request.method, + request.host, + request.port, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + else: + raise http.HttpError(400, "Invalid request form") + return request_line + + def _assemble_request_headers(self, request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + del headers[k] + if 'host' not in headers and request.scheme and request.host and request.port: + headers["Host"] = [utils.hostport(request.scheme, + request.host, + request.port)] + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == "": + headers["Content-Length"] = [str(len(request.body))] + + return headers.format() + + + def _assemble_response_first_line(self, response): + return 'HTTP/%s.%s %s %s' % ( + response.httpversion[0], + response.httpversion[1], + response.status_code, + response.msg, + ) + + def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + headers = response.headers.copy() + for k in response._headers_to_strip_off: + del headers[k] + if not preserve_transfer_encoding: + del headers['Transfer-Encoding'] + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == "": + headers["Content-Length"] = [str(len(response.body))] + + return headers.format() diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 55b5ca763..a1ca4a182 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -1,12 +1,20 @@ from __future__ import (absolute_import, print_function, division) import itertools +import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict +from netlib.http import semantics from . import frame -class HTTP2Protocol(object): +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + +class HTTP2Protocol(semantics.ProtocolMixin): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, @@ -31,16 +39,182 @@ class HTTP2Protocol(object): ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() self.connection_preface_performed = False - self.dump_frames = dump_frames + + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + request = http.Request( + "relative", # TODO: use the correct value + headers.get_first(':method', 'GET'), + headers.get_first(':scheme', 'https'), + headers.get_first(':host', 'localhost'), + 443, # TODO: parse port number from host? + headers.get_first(':path', '/'), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) + request.stream_id = stream_id + + return request + + def read_response(self, request_method='', body_size_limit=None, include_body=True): + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + int(headers.get_first(':status')), + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = request.headers.copy() + + if not ':authority' in headers.keys(): + headers.add(':authority', bytes(authority), prepend=True) + if not ':scheme' in headers.keys(): + headers.add(':scheme', bytes(request.scheme), prepend=True) + if not ':path' in headers.keys(): + headers.add(':path', bytes(request.path), prepend=True) + if not ':method' in headers.keys(): + headers.add(':method', bytes(request.method), prepend=True) + + headers = headers.items() + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + headers = response.headers.copy() + + if not ':status' in headers.keys(): + headers.add(':status', bytes(str(response.status_code)), prepend=True) + + headers = headers.items() + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), + self._create_body(response.body, stream_id), + )) + + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + + return frm def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -63,27 +237,7 @@ class HTTP2Protocol(object): assert len(frm.settings) == 0 break - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def next_stream_id(self): + def _next_stream_id(self): if self.current_stream_id is None: if self.is_server: # servers must use even stream ids @@ -95,22 +249,6 @@ class HTTP2Protocol(object): self.current_stream_id += 2 return self.current_stream_id - def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - - return frm - def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] @@ -164,51 +302,7 @@ class HTTP2Protocol(object): return [frm.to_bytes()] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self, *args): - stream_id, headers, body = self._receive_transmission() - - status = headers[':status'][0] - response = http.Response("HTTP/2", status, "", headers, body) - response.stream_id = stream_id - return response - - def read_request(self): - stream_id, headers, body = self._receive_transmission() - - form_in = "" - method = headers.get(':method', [''])[0] - scheme = headers.get(':scheme', [''])[0] - host = headers.get(':host', [''])[0] - port = '' # TODO: parse port number? - path = headers.get(':path', [''])[0] - - request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) - request.stream_id = stream_id - return request - - def _receive_transmission(self): + def _receive_transmission(self, include_body=True): body_expected = True stream_id = 0 @@ -239,19 +333,3 @@ class HTTP2Protocol(object): headers.add(header, value) return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - if isinstance(headers, odict.ODict): - headers = headers.items() - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9e13edaa4..54bf83d27 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,32 @@ import urlparse from .. import utils, odict +CONTENT_MISSING = 0 + + +class ProtocolMixin(object): + + def read_request(self): + raise NotImplemented + + def read_response(self): + raise NotImplemented + + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + + def assemble_request(self, request): + raise NotImplemented + + def assemble_response(self, response): + raise NotImplemented + + class Request(object): def __init__( @@ -18,9 +44,15 @@ class Request(object): port, path, httpversion, - headers, - body, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, ): + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) + self.form_in = form_in self.method = method self.scheme = scheme @@ -30,17 +62,31 @@ class Request(object): self.httpversion = httpversion self.headers = headers self.body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + class EmptyRequest(Request): def __init__(self): @@ -63,28 +109,59 @@ class Response(object): self, httpversion, status_code, - msg, - headers, - body, + msg=None, + headers=None, + body=None, sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) + self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers self.body = body self.sslinfo = sslinfo + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + + @property + def code(self): + # TODO: remove deprecated getter + return self.status_code + + @code.setter + def code(self, code): + # TODO: remove deprecated setter + self.status_code = code + + def is_valid_port(port): if not 0 <= port <= 65535: diff --git a/netlib/odict.py b/netlib/odict.py index f52acd504..d02de08d1 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,8 +96,11 @@ class ODict(object): return True return False - def add(self, key, value): - self.lst.append([key, value]) + def add(self, key, value, prepend=False): + if prepend: + self.lst.insert(0, [key, value]) + else: + self.lst.append([key, value]) def get(self, k, d=None): if k in self: diff --git a/netlib/utils.py b/netlib/utils.py index bee412f96..86e33f333 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -129,3 +129,13 @@ class Data(object): if not os.path.exists(fullpath): raise ValueError("dataPath: %s does not exist." % fullpath) return fullpath + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https")]: + return host + else: + return "%s:%s" % (host, port) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index dcebbd5ee..b196b7a3d 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -297,10 +297,10 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_read_response(): - def tst(data, method, limit, include_body=True): + def tst(data, method, body_size_limit, include_body=True): data = textwrap.dedent(data) return mock_protocol(data).read_response( - method, limit, include_body=include_body + method, body_size_limit, include_body=include_body ) tutils.raises("server disconnect", tst, "", "GET", None) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index d30402663..b2d414d16 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,6 +1,6 @@ import OpenSSL -from netlib import tcp, odict +from netlib import tcp, odict, http from netlib.http import http2 from netlib.http.http2.frame import * from ... import tutils, tservers @@ -117,11 +117,11 @@ class TestClientStreamIds(): def test_client_stream_ids(self): assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 1 + assert self.protocol._next_stream_id() == 1 assert self.protocol.current_stream_id == 1 - assert self.protocol.next_stream_id() == 3 + assert self.protocol._next_stream_id() == 3 assert self.protocol.current_stream_id == 3 - assert self.protocol.next_stream_id() == 5 + assert self.protocol._next_stream_id() == 5 assert self.protocol.current_stream_id == 5 @@ -131,11 +131,11 @@ class TestServerStreamIds(): def test_server_stream_ids(self): assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 2 + assert self.protocol._next_stream_id() == 2 assert self.protocol.current_stream_id == 2 - assert self.protocol.next_stream_id() == 4 + assert self.protocol._next_stream_id() == 4 assert self.protocol.current_stream_id == 4 - assert self.protocol.next_stream_id() == 6 + assert self.protocol._next_stream_id() == 6 assert self.protocol.current_stream_id == 6 @@ -215,17 +215,36 @@ class TestCreateBody(): # TODO: add test for too large frames -class TestCreateRequest(): +class TestAssembleRequest(): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') + def test_assemble_request_simple(self): + bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + )) assert len(bytes) == 1 assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - def test_create_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).create_request( - 'GET', '/', [(b'foo', b'bar')], 'foobar') + def test_assemble_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + odict.ODictCaseless([('foo', 'bar')]), + 'foobar', + )) assert len(bytes) == 2 assert bytes[0] ==\ '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') @@ -250,11 +269,12 @@ class TestReadResponse(tservers.ServerTestBase): c.connect() c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) + protocol.connection_preface_performed = True resp = protocol.read_response() - assert resp.httpversion == "HTTP/2" - assert resp.status_code == "200" + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' @@ -275,12 +295,13 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c.connect() c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) + protocol.connection_preface_performed = True resp = protocol.read_response() assert resp.stream_id - assert resp.httpversion == "HTTP/2" - assert resp.status_code == "200" + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'' @@ -303,6 +324,7 @@ class TestReadRequest(tservers.ServerTestBase): c.connect() c.convert_to_ssl() protocol = http2.HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True resp = protocol.read_request() @@ -315,16 +337,24 @@ class TestCreateResponse(): c = tcp.TCPClient(("127.0.0.1", 0)) def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + (2, 0), + 200, + )) assert len(bytes) == 1 assert bytes[0] ==\ '00000101050000000288'.decode('hex') def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, 1, [(b'foo', b'bar')], 'foobar') + bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + (2, 0), + 200, + '', + odict.ODictCaseless([('foo', 'bar')]), + 'foobar' + )) assert len(bytes) == 2 assert bytes[0] ==\ - '00000901040000000188408294e7838c767f'.decode('hex') + '00000901040000000288408294e7838c767f'.decode('hex') assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') + '000006000100000002666f6f626172'.decode('hex')