http2: improve test suite

This commit is contained in:
Thomas Kriechbaumer 2015-06-15 17:31:08 +02:00
parent 20c136e070
commit abb37a3ef5
4 changed files with 53 additions and 26 deletions

View File

@ -55,7 +55,7 @@ class HTTP2Protocol(object):
if isinstance(frm, frame.SettingsFrame):
break
def _read_settings_ack(self, hide=False):
def _read_settings_ack(self, hide=False): # pragma no cover
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
@ -99,12 +99,12 @@ class HTTP2Protocol(object):
raw_bytes = frm.to_bytes()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
if not hide and self.dump_frames:
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
def read_frame(self, hide=False):
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
if not hide and self.dump_frames:
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK:
self._apply_settings(frm.settings, hide)
@ -123,7 +123,9 @@ class HTTP2Protocol(object):
state=self,
flags=frame.Frame.FLAG_ACK),
hide)
# self._read_settings_ack(hide)
# be liberal in what we expect from the other end
# to be more strict use: self._read_settings_ack(hide)
def _create_headers(self, headers, stream_id, end_stream=True):
# TODO: implement max frame size checks and sending in chunks
@ -140,7 +142,7 @@ class HTTP2Protocol(object):
stream_id=stream_id,
header_block_fragment=header_block_fragment)
if self.dump_frames:
if self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
return [frm.to_bytes()]
@ -158,7 +160,7 @@ class HTTP2Protocol(object):
stream_id=stream_id,
payload=body)
if self.dump_frames:
if self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
return [frm.to_bytes()]
@ -225,8 +227,6 @@ class HTTP2Protocol(object):
if headers is None:
headers = []
body='foobar'
headers = [(b':status', bytes(str(code)))] + headers
if not stream_id:

View File

@ -414,6 +414,9 @@ class _Connection(object):
if cipher_list:
try:
context.set_cipher_list(cipher_list)
# TODO: maybe change this to with newer pyOpenSSL APIs
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v:
raise NetLibError("SSL cipher specification error: %s" % str(v))
@ -421,8 +424,6 @@ class _Connection(object):
if log_ssl_key:
context.set_info_callback(log_ssl_key)
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
if alpn_protos is not None:
# advertise application layer protocols
@ -526,7 +527,7 @@ class TCPClient(_Connection):
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else:
return None
return ""
class BaseHandler(_Connection):
@ -636,7 +637,7 @@ class BaseHandler(_Connection):
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
return self.connection.get_alpn_proto_negotiated()
else:
return None
return ""
class TCPServer(object):

View File

@ -300,8 +300,9 @@ class TestReadRequest(test.ServerTestBase):
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c, is_server=True)
headers, body = protocol.read_request()
stream_id, headers, body = protocol.read_request()
assert stream_id
assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'}
assert body == b'foobar'
@ -309,17 +310,17 @@ class TestReadRequest(test.ServerTestBase):
class TestCreateResponse():
c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_request_simple(self):
def test_create_response_simple(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200)
assert len(bytes) == 1
assert bytes[0] ==\
'00000101050000000288'.decode('hex')
def test_create_request_with_body(self):
def test_create_response_with_body(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(
200, [(b'foo', b'bar')], 'foobar')
200, 1, [(b'foo', b'bar')], 'foobar')
assert len(bytes) == 2
assert bytes[0] ==\
'00000901040000000288408294e7838c767f'.decode('hex')
'00000901040000000188408294e7838c767f'.decode('hex')
assert bytes[1] ==\
'000006000100000002666f6f626172'.decode('hex')
'000006000100000001666f6f626172'.decode('hex')

View File

@ -41,6 +41,18 @@ class HangHandler(tcp.BaseHandler):
time.sleep(1)
class ALPNHandler(tcp.BaseHandler):
sni = None
def handle(self):
alp = self.get_alpn_proto_negotiated()
if alp:
self.wfile.write("%s" % alp)
else:
self.wfile.write("NONE")
self.wfile.flush()
class TestServer(test.ServerTestBase):
handler = EchoHandler
@ -416,30 +428,43 @@ class TestTimeOut(test.ServerTestBase):
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
class TestALPN(test.ServerTestBase):
handler = EchoHandler
class TestALPNClient(test.ServerTestBase):
handler = ALPNHandler
ssl = dict(
alpn_select="foobar"
alpn_select="bar"
)
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
def test_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(alpn_protos=["foobar"])
assert c.get_alpn_proto_negotiated() == "foobar"
c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"])
assert c.get_alpn_proto_negotiated() == "bar"
assert c.rfile.readline().strip() == "bar"
def test_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
assert c.get_alpn_proto_negotiated() == None
c.convert_to_ssl()
assert c.get_alpn_proto_negotiated() == ""
assert c.rfile.readline().strip() == "NONE"
else:
def test_none_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(alpn_protos=["foobar"])
assert c.get_alpn_proto_negotiated() == None
c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"])
assert c.get_alpn_proto_negotiated() == ""
assert c.rfile.readline() == "NONE"
class TestNoSSLNoALPNClient(test.ServerTestBase):
handler = ALPNHandler
def test_no_ssl_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
assert c.get_alpn_proto_negotiated() == ""
assert c.rfile.readline().strip() == "NONE"
class TestSSLTimeOut(test.ServerTestBase):