Use handlers in http2 test suite

This commit is contained in:
Aldo Cortesi 2016-06-14 12:09:13 +12:00
parent e6fd98bb72
commit 9e63350a96

View File

@ -75,10 +75,10 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
def test_check_alpn(self): def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl(alpn_protos=[b'h2']) c.convert_to_ssl(alpn_protos=[b'h2'])
protocol = HTTP2Protocol(c) protocol = HTTP2Protocol(c)
assert protocol.check_alpn() assert protocol.check_alpn()
class TestCheckALPNMismatch(tservers.ServerTestBase): class TestCheckALPNMismatch(tservers.ServerTestBase):
@ -91,11 +91,11 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
def test_check_alpn(self): def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl(alpn_protos=[b'h2']) c.convert_to_ssl(alpn_protos=[b'h2'])
protocol = HTTP2Protocol(c) protocol = HTTP2Protocol(c)
with raises(NotImplementedError): with raises(NotImplementedError):
protocol.check_alpn() protocol.check_alpn()
class TestPerformServerConnectionPreface(tservers.ServerTestBase): class TestPerformServerConnectionPreface(tservers.ServerTestBase):
@ -124,15 +124,15 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
def test_perform_server_connection_preface(self): def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
protocol = HTTP2Protocol(c) protocol = HTTP2Protocol(c)
assert not protocol.connection_preface_performed assert not protocol.connection_preface_performed
protocol.perform_server_connection_preface() protocol.perform_server_connection_preface()
assert protocol.connection_preface_performed assert protocol.connection_preface_performed
with raises(TcpDisconnect): with raises(TcpDisconnect):
protocol.perform_server_connection_preface(force=True) protocol.perform_server_connection_preface(force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase): class TestPerformClientConnectionPreface(tservers.ServerTestBase):
@ -160,12 +160,12 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
def test_perform_client_connection_preface(self): def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
protocol = HTTP2Protocol(c) protocol = HTTP2Protocol(c)
assert not protocol.connection_preface_performed assert not protocol.connection_preface_performed
protocol.perform_client_connection_preface() protocol.perform_client_connection_preface()
assert protocol.connection_preface_performed assert protocol.connection_preface_performed
class TestClientStreamIds(object): class TestClientStreamIds(object):
@ -209,24 +209,24 @@ class TestApplySettings(tservers.ServerTestBase):
def test_apply_settings(self): def test_apply_settings(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c) protocol = HTTP2Protocol(c)
protocol._apply_settings({ protocol._apply_settings({
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo', hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar', hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
}) })
assert c.rfile.safe_read(2) == b"OK" assert c.rfile.safe_read(2) == b"OK"
assert protocol.http2_settings[ assert protocol.http2_settings[
hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo' hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo'
assert protocol.http2_settings[ assert protocol.http2_settings[
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar' hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
assert protocol.http2_settings[ assert protocol.http2_settings[
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
class TestCreateHeaders(object): class TestCreateHeaders(object):
@ -304,19 +304,19 @@ class TestReadRequest(tservers.ServerTestBase):
def test_read_request(self): def test_read_request(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
assert req.stream_id assert req.stream_id
assert req.headers.fields == () assert req.headers.fields == ()
assert req.method == "GET" assert req.method == "GET"
assert req.path == "/" assert req.path == "/"
assert req.scheme == "https" assert req.scheme == "https"
assert req.content == b'foobar' assert req.content == b'foobar'
class TestReadRequestRelative(tservers.ServerTestBase): class TestReadRequestRelative(tservers.ServerTestBase):
@ -330,16 +330,16 @@ class TestReadRequestRelative(tservers.ServerTestBase):
def test_asterisk_form(self): def test_asterisk_form(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
assert req.first_line_format == "relative" assert req.first_line_format == "relative"
assert req.method == "OPTIONS" assert req.method == "OPTIONS"
assert req.path == "*" assert req.path == "*"
class TestReadRequestAbsolute(tservers.ServerTestBase): class TestReadRequestAbsolute(tservers.ServerTestBase):
@ -353,17 +353,17 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
def test_absolute_form(self): def test_absolute_form(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
assert req.first_line_format == "absolute" assert req.first_line_format == "absolute"
assert req.scheme == "http" assert req.scheme == "http"
assert req.host == "address" assert req.host == "address"
assert req.port == 22 assert req.port == 22
class TestReadRequestConnect(tservers.ServerTestBase): class TestReadRequestConnect(tservers.ServerTestBase):
@ -379,22 +379,22 @@ class TestReadRequestConnect(tservers.ServerTestBase):
def test_connect(self): def test_connect(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
assert req.first_line_format == "authority" assert req.first_line_format == "authority"
assert req.method == "CONNECT" assert req.method == "CONNECT"
assert req.host == "address" assert req.host == "address"
assert req.port == 22 assert req.port == 22
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
assert req.first_line_format == "authority" assert req.first_line_format == "authority"
assert req.method == "CONNECT" assert req.method == "CONNECT"
assert req.host == "example.com" assert req.host == "example.com"
assert req.port == 443 assert req.port == 443
class TestReadResponse(tservers.ServerTestBase): class TestReadResponse(tservers.ServerTestBase):
@ -411,19 +411,19 @@ class TestReadResponse(tservers.ServerTestBase):
def test_read_response(self): def test_read_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c) protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
resp = protocol.read_response(NotImplemented, stream_id=42) resp = protocol.read_response(NotImplemented, stream_id=42)
assert resp.http_version == "HTTP/2.0" assert resp.http_version == "HTTP/2.0"
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.reason == '' assert resp.reason == ''
assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
assert resp.content == b'foobar' assert resp.content == b'foobar'
assert resp.timestamp_end assert resp.timestamp_end
class TestReadEmptyResponse(tservers.ServerTestBase): class TestReadEmptyResponse(tservers.ServerTestBase):
@ -437,19 +437,19 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
def test_read_empty_response(self): def test_read_empty_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c) protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
resp = protocol.read_response(NotImplemented, stream_id=42) resp = protocol.read_response(NotImplemented, stream_id=42)
assert resp.stream_id == 42 assert resp.stream_id == 42
assert resp.http_version == "HTTP/2.0" assert resp.http_version == "HTTP/2.0"
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.reason == '' assert resp.reason == ''
assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
assert resp.content == b'' assert resp.content == b''
class TestAssembleRequest(object): class TestAssembleRequest(object):