refactor http2 tests

This commit is contained in:
Thomas Kriechbaumer 2016-01-26 13:15:20 +01:00
parent 187691e65b
commit 276817e40e

View File

@ -4,8 +4,16 @@ import inspect
import socket import socket
import OpenSSL import OpenSSL
import pytest import pytest
import traceback
import os
import tempfile
from io import BytesIO from io import BytesIO
from libmproxy.proxy.config import ProxyConfig
from libmproxy.proxy.server import ProxyServer
from libmproxy.cmdline import APP_HOST, APP_PORT
import logging import logging
logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING)
logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING) logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING)
@ -18,6 +26,7 @@ import netlib
from netlib import tservers as netlib_tservers from netlib import tservers as netlib_tservers
import h2 import h2
from hyperframe.frame import Frame
from libmproxy import utils from libmproxy import utils
from . import tservers from . import tservers
@ -26,8 +35,7 @@ requires_alpn = pytest.mark.skipif(
not OpenSSL._util.lib.Cryptography_HAS_ALPN, not OpenSSL._util.lib.Cryptography_HAS_ALPN,
reason="requires OpenSSL with ALPN support") reason="requires OpenSSL with ALPN support")
class _Http2ServerBase(netlib_tservers.ServerTestBase):
class SimpleHttp2Server(netlib_tservers.ServerTestBase):
ssl = dict(alpn_select=b'h2') ssl = dict(alpn_select=b'h2')
class handler(netlib.tcp.BaseHandler): class handler(netlib.tcp.BaseHandler):
@ -41,78 +49,56 @@ class SimpleHttp2Server(netlib_tservers.ServerTestBase):
self.wfile.flush() self.wfile.flush()
while True: while True:
events = h2_conn.receive_data(utils.http2_read_frame(self.rfile)) raw_frame = utils.http2_read_frame(self.rfile)
events = h2_conn.receive_data(raw_frame)
self.wfile.write(h2_conn.data_to_send()) self.wfile.write(h2_conn.data_to_send())
self.wfile.flush() self.wfile.flush()
for event in events: for event in events:
if isinstance(event, h2.events.RequestReceived): try:
h2_conn.send_headers(1, [ if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile):
(':status', '200'), break
('foo', 'bar'), except Exception as e:
]) print(repr(e))
h2_conn.send_data(1, b'foobar') print(traceback.format_exc())
h2_conn.end_stream(1) break
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush() def handle_server_event(self, h2_conn, rfile, wfile):
elif isinstance(event, h2.events.ConnectionTerminated): raise NotImplementedError()
return
class PushHttp2Server(netlib_tservers.ServerTestBase): class _Http2TestBase(object):
ssl = dict(alpn_select=b'h2') @classmethod
def setup_class(self):
self.config = ProxyConfig(**self.get_proxy_config())
class handler(netlib.tcp.BaseHandler): tmaster = tservers.TestMaster(self.config)
def handle(self): tmaster.start_app(APP_HOST, APP_PORT)
h2_conn = h2.connection.H2Connection(client_side=False) self.proxy = tservers.ProxyThread(tmaster)
self.proxy.start()
preamble = self.rfile.read(24) @classmethod
h2_conn.initiate_connection() def teardown_class(cls):
h2_conn.receive_data(preamble) cls.proxy.shutdown()
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
while True: @property
events = h2_conn.receive_data(utils.http2_read_frame(self.rfile)) def master(self):
self.wfile.write(h2_conn.data_to_send()) return self.proxy.tmaster
self.wfile.flush()
for event in events: @classmethod
if isinstance(event, h2.events.RequestReceived): def get_proxy_config(cls):
h2_conn.send_headers(1, [(':status', '200')]) cls.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
h2_conn.push_stream(1, 2, [ return dict(
(':authority', "127.0.0.1:%s" % self.address.port), no_upstream_cert = False,
(':method', 'GET'), cadir = cls.cadir,
(':scheme', 'https'), authenticator = None,
(':path', '/pushed_stream_foo'), )
('foo', 'bar')
])
h2_conn.push_stream(1, 4, [
(':authority', "127.0.0.1:%s" % self.address.port),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/pushed_stream_bar'),
('foo', 'bar')
])
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
h2_conn.send_headers(2, [(':status', '202')]) def setup(self):
h2_conn.send_headers(4, [(':status', '204')]) self.master.clear_log()
h2_conn.send_data(1, b'regular_stream') self.master.state.clear()
h2_conn.send_data(2, b'pushed_stream_foo') self.server.server.handle_server_event = self.handle_server_event
h2_conn.send_data(4, b'pushed_stream_bar')
h2_conn.end_stream(1)
h2_conn.end_stream(2)
h2_conn.end_stream(4)
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
elif isinstance(event, h2.events.ConnectionTerminated):
return
@requires_alpn
class TestHttp2(tservers.ProxTestBase):
def _setup_connection(self): def _setup_connection(self):
self.config.http2 = True self.config.http2 = True
@ -123,7 +109,7 @@ class TestHttp2(tservers.ProxTestBase):
client.wfile.write( client.wfile.write(
b"CONNECT localhost:%d HTTP/1.1\r\n" b"CONNECT localhost:%d HTTP/1.1\r\n"
b"Host: localhost:%d\r\n" b"Host: localhost:%d\r\n"
b"\r\n" % (self.server.port, self.server.port) b"\r\n" % (self.server.server.address.port, self.server.server.address.port)
) )
client.wfile.flush() client.wfile.flush()
@ -149,14 +135,40 @@ class TestHttp2(tservers.ProxTestBase):
wfile.write(h2_conn.data_to_send()) wfile.write(h2_conn.data_to_send())
wfile.flush() wfile.flush()
def test_simple(self):
self.server = SimpleHttp2Server()
self.server.setup_class()
@requires_alpn
class TestSimple(_Http2TestBase, _Http2ServerBase):
@classmethod
def setup_class(self):
_Http2TestBase.setup_class()
_Http2ServerBase.setup_class()
@classmethod
def teardown_class(self):
_Http2TestBase.teardown_class()
_Http2ServerBase.teardown_class()
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
h2_conn.send_headers(1, [
(':status', '200'),
('foo', 'bar'),
])
h2_conn.send_data(1, b'foobar')
h2_conn.end_stream(1)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
def test_simple(self):
client, h2_conn = self._setup_connection() client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, headers=[ self._send_request(client.wfile, h2_conn, headers=[
(':authority', "127.0.0.1:%s" % self.server.port), (':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -176,21 +188,69 @@ class TestHttp2(tservers.ProxTestBase):
client.wfile.write(h2_conn.data_to_send()) client.wfile.write(h2_conn.data_to_send())
client.wfile.flush() client.wfile.flush()
self.server.teardown_class()
assert len(self.master.state.flows) == 1 assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response.status_code == 200 assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['foo'] == 'bar' assert self.master.state.flows[0].response.headers['foo'] == 'bar'
assert self.master.state.flows[0].response.body == b'foobar' assert self.master.state.flows[0].response.body == b'foobar'
def test_pushed_streams(self):
self.server = PushHttp2Server()
self.server.setup_class()
@requires_alpn
class TestPushPromise(_Http2TestBase, _Http2ServerBase):
@classmethod
def setup_class(self):
_Http2TestBase.setup_class()
_Http2ServerBase.setup_class()
@classmethod
def teardown_class(self):
_Http2TestBase.teardown_class()
_Http2ServerBase.teardown_class()
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
if event.stream_id != 1:
# ignore requests initiated by push promises
return True
h2_conn.send_headers(1, [(':status', '200')])
h2_conn.push_stream(1, 2, [
(':authority', "127.0.0.1:%s" % self.port),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/pushed_stream_foo'),
('foo', 'bar')
])
h2_conn.push_stream(1, 4, [
(':authority', "127.0.0.1:%s" % self.port),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/pushed_stream_bar'),
('foo', 'bar')
])
wfile.write(h2_conn.data_to_send())
wfile.flush()
h2_conn.send_headers(2, [(':status', '202')])
h2_conn.send_headers(4, [(':status', '204')])
h2_conn.send_data(1, b'regular_stream')
h2_conn.send_data(2, b'pushed_stream_foo')
h2_conn.send_data(4, b'pushed_stream_bar')
h2_conn.end_stream(1)
h2_conn.end_stream(2)
h2_conn.end_stream(4)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
def test_push_promise(self):
client, h2_conn = self._setup_connection() client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, headers=[ self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:%s" % self.server.port), (':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -198,6 +258,7 @@ class TestHttp2(tservers.ProxTestBase):
]) ])
ended_streams = 0 ended_streams = 0
pushed_streams = 0
while ended_streams != 3: while ended_streams != 3:
try: try:
events = h2_conn.receive_data(utils.http2_read_frame(client.rfile)) events = h2_conn.receive_data(utils.http2_read_frame(client.rfile))
@ -209,10 +270,13 @@ class TestHttp2(tservers.ProxTestBase):
for event in events: for event in events:
if isinstance(event, h2.events.StreamEnded): if isinstance(event, h2.events.StreamEnded):
ended_streams += 1 ended_streams += 1
elif isinstance(event, h2.events.PushedStreamReceived):
pushed_streams += 1
self.server.teardown_class() assert pushed_streams == 2
assert len(self.master.state.flows) == 3 bodies = [flow.response.body for flow in self.master.state.flows]
assert self.master.state.flows[0].response.body == b'regular_stream' assert len(bodies) == 3
assert self.master.state.flows[1].response.body == b'pushed_stream_foo' assert b'regular_stream' in bodies
assert self.master.state.flows[2].response.body == b'pushed_stream_bar' assert b'pushed_stream_foo' in bodies
assert b'pushed_stream_bar' in bodies