mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
http2: improve test suite
This commit is contained in:
parent
20c136e070
commit
abb37a3ef5
@ -55,7 +55,7 @@ class HTTP2Protocol(object):
|
||||
if isinstance(frm, frame.SettingsFrame):
|
||||
break
|
||||
|
||||
def _read_settings_ack(self, hide=False):
|
||||
def _read_settings_ack(self, hide=False): # pragma no cover
|
||||
while True:
|
||||
frm = self.read_frame(hide)
|
||||
if isinstance(frm, frame.SettingsFrame):
|
||||
@ -99,12 +99,12 @@ class HTTP2Protocol(object):
|
||||
raw_bytes = frm.to_bytes()
|
||||
self.tcp_handler.wfile.write(raw_bytes)
|
||||
self.tcp_handler.wfile.flush()
|
||||
if not hide and self.dump_frames:
|
||||
if not hide and self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable(">>"))
|
||||
|
||||
def read_frame(self, hide=False):
|
||||
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
|
||||
if not hide and self.dump_frames:
|
||||
if not hide and self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable("<<"))
|
||||
if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK:
|
||||
self._apply_settings(frm.settings, hide)
|
||||
@ -123,7 +123,9 @@ class HTTP2Protocol(object):
|
||||
state=self,
|
||||
flags=frame.Frame.FLAG_ACK),
|
||||
hide)
|
||||
# self._read_settings_ack(hide)
|
||||
|
||||
# be liberal in what we expect from the other end
|
||||
# to be more strict use: self._read_settings_ack(hide)
|
||||
|
||||
def _create_headers(self, headers, stream_id, end_stream=True):
|
||||
# TODO: implement max frame size checks and sending in chunks
|
||||
@ -140,7 +142,7 @@ class HTTP2Protocol(object):
|
||||
stream_id=stream_id,
|
||||
header_block_fragment=header_block_fragment)
|
||||
|
||||
if self.dump_frames:
|
||||
if self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable(">>"))
|
||||
|
||||
return [frm.to_bytes()]
|
||||
@ -158,7 +160,7 @@ class HTTP2Protocol(object):
|
||||
stream_id=stream_id,
|
||||
payload=body)
|
||||
|
||||
if self.dump_frames:
|
||||
if self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable(">>"))
|
||||
|
||||
return [frm.to_bytes()]
|
||||
@ -225,8 +227,6 @@ class HTTP2Protocol(object):
|
||||
if headers is None:
|
||||
headers = []
|
||||
|
||||
body='foobar'
|
||||
|
||||
headers = [(b':status', bytes(str(code)))] + headers
|
||||
|
||||
if not stream_id:
|
||||
|
@ -414,6 +414,9 @@ class _Connection(object):
|
||||
if cipher_list:
|
||||
try:
|
||||
context.set_cipher_list(cipher_list)
|
||||
|
||||
# TODO: maybe change this to with newer pyOpenSSL APIs
|
||||
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
|
||||
except SSL.Error as v:
|
||||
raise NetLibError("SSL cipher specification error: %s" % str(v))
|
||||
|
||||
@ -421,8 +424,6 @@ class _Connection(object):
|
||||
if log_ssl_key:
|
||||
context.set_info_callback(log_ssl_key)
|
||||
|
||||
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
|
||||
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
|
||||
if alpn_protos is not None:
|
||||
# advertise application layer protocols
|
||||
@ -526,7 +527,7 @@ class TCPClient(_Connection):
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
|
||||
return self.connection.get_alpn_proto_negotiated()
|
||||
else:
|
||||
return None
|
||||
return ""
|
||||
|
||||
|
||||
class BaseHandler(_Connection):
|
||||
@ -636,7 +637,7 @@ class BaseHandler(_Connection):
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
|
||||
return self.connection.get_alpn_proto_negotiated()
|
||||
else:
|
||||
return None
|
||||
return ""
|
||||
|
||||
|
||||
class TCPServer(object):
|
||||
|
@ -300,8 +300,9 @@ class TestReadRequest(test.ServerTestBase):
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c, is_server=True)
|
||||
|
||||
headers, body = protocol.read_request()
|
||||
stream_id, headers, body = protocol.read_request()
|
||||
|
||||
assert stream_id
|
||||
assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'}
|
||||
assert body == b'foobar'
|
||||
|
||||
@ -309,17 +310,17 @@ class TestReadRequest(test.ServerTestBase):
|
||||
class TestCreateResponse():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
||||
def test_create_request_simple(self):
|
||||
def test_create_response_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200)
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] ==\
|
||||
'00000101050000000288'.decode('hex')
|
||||
|
||||
def test_create_request_with_body(self):
|
||||
def test_create_response_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(
|
||||
200, [(b'foo', b'bar')], 'foobar')
|
||||
200, 1, [(b'foo', b'bar')], 'foobar')
|
||||
assert len(bytes) == 2
|
||||
assert bytes[0] ==\
|
||||
'00000901040000000288408294e7838c767f'.decode('hex')
|
||||
'00000901040000000188408294e7838c767f'.decode('hex')
|
||||
assert bytes[1] ==\
|
||||
'000006000100000002666f6f626172'.decode('hex')
|
||||
'000006000100000001666f6f626172'.decode('hex')
|
||||
|
@ -41,6 +41,18 @@ class HangHandler(tcp.BaseHandler):
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
class ALPNHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle(self):
|
||||
alp = self.get_alpn_proto_negotiated()
|
||||
if alp:
|
||||
self.wfile.write("%s" % alp)
|
||||
else:
|
||||
self.wfile.write("NONE")
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class TestServer(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
@ -416,30 +428,43 @@ class TestTimeOut(test.ServerTestBase):
|
||||
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
|
||||
|
||||
|
||||
class TestALPN(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
class TestALPNClient(test.ServerTestBase):
|
||||
handler = ALPNHandler
|
||||
ssl = dict(
|
||||
alpn_select="foobar"
|
||||
alpn_select="bar"
|
||||
)
|
||||
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
|
||||
def test_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(alpn_protos=["foobar"])
|
||||
assert c.get_alpn_proto_negotiated() == "foobar"
|
||||
c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"])
|
||||
assert c.get_alpn_proto_negotiated() == "bar"
|
||||
assert c.rfile.readline().strip() == "bar"
|
||||
|
||||
def test_no_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
assert c.get_alpn_proto_negotiated() == None
|
||||
c.convert_to_ssl()
|
||||
assert c.get_alpn_proto_negotiated() == ""
|
||||
assert c.rfile.readline().strip() == "NONE"
|
||||
|
||||
else:
|
||||
def test_none_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(alpn_protos=["foobar"])
|
||||
assert c.get_alpn_proto_negotiated() == None
|
||||
c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"])
|
||||
assert c.get_alpn_proto_negotiated() == ""
|
||||
assert c.rfile.readline() == "NONE"
|
||||
|
||||
class TestNoSSLNoALPNClient(test.ServerTestBase):
|
||||
handler = ALPNHandler
|
||||
|
||||
def test_no_ssl_no_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
assert c.get_alpn_proto_negotiated() == ""
|
||||
assert c.rfile.readline().strip() == "NONE"
|
||||
|
||||
|
||||
class TestSSLTimeOut(test.ServerTestBase):
|
||||
|
Loading…
Reference in New Issue
Block a user