From 276817e40e99dbb2ddc7638839bd74e944fd704e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 26 Jan 2016 13:15:20 +0100 Subject: [PATCH] refactor http2 tests --- test/test_protocol_http2.py | 222 +++++++++++++++++++++++------------- 1 file changed, 143 insertions(+), 79 deletions(-) diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py index 17687b455..b42b86cb6 100644 --- a/test/test_protocol_http2.py +++ b/test/test_protocol_http2.py @@ -4,8 +4,16 @@ import inspect import socket import OpenSSL import pytest +import traceback +import os +import tempfile + from io import BytesIO +from libmproxy.proxy.config import ProxyConfig +from libmproxy.proxy.server import ProxyServer +from libmproxy.cmdline import APP_HOST, APP_PORT + import logging logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING) @@ -18,6 +26,7 @@ import netlib from netlib import tservers as netlib_tservers import h2 +from hyperframe.frame import Frame from libmproxy import utils from . import tservers @@ -26,8 +35,7 @@ requires_alpn = pytest.mark.skipif( not OpenSSL._util.lib.Cryptography_HAS_ALPN, reason="requires OpenSSL with ALPN support") - -class SimpleHttp2Server(netlib_tservers.ServerTestBase): +class _Http2ServerBase(netlib_tservers.ServerTestBase): ssl = dict(alpn_select=b'h2') class handler(netlib.tcp.BaseHandler): @@ -41,78 +49,56 @@ class SimpleHttp2Server(netlib_tservers.ServerTestBase): self.wfile.flush() while True: - events = h2_conn.receive_data(utils.http2_read_frame(self.rfile)) + raw_frame = utils.http2_read_frame(self.rfile) + events = h2_conn.receive_data(raw_frame) self.wfile.write(h2_conn.data_to_send()) self.wfile.flush() for event in events: - if isinstance(event, h2.events.RequestReceived): - h2_conn.send_headers(1, [ - (':status', '200'), - ('foo', 'bar'), - ]) - h2_conn.send_data(1, b'foobar') - h2_conn.end_stream(1) - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() - elif isinstance(event, h2.events.ConnectionTerminated): - return + try: + if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): + break + except Exception as e: + print(repr(e)) + print(traceback.format_exc()) + break + + def handle_server_event(self, h2_conn, rfile, wfile): + raise NotImplementedError() -class PushHttp2Server(netlib_tservers.ServerTestBase): - ssl = dict(alpn_select=b'h2') +class _Http2TestBase(object): + @classmethod + def setup_class(self): + self.config = ProxyConfig(**self.get_proxy_config()) - class handler(netlib.tcp.BaseHandler): - def handle(self): - h2_conn = h2.connection.H2Connection(client_side=False) + tmaster = tservers.TestMaster(self.config) + tmaster.start_app(APP_HOST, APP_PORT) + self.proxy = tservers.ProxyThread(tmaster) + self.proxy.start() - preamble = self.rfile.read(24) - h2_conn.initiate_connection() - h2_conn.receive_data(preamble) - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() - while True: - events = h2_conn.receive_data(utils.http2_read_frame(self.rfile)) - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() + @property + def master(self): + return self.proxy.tmaster - for event in events: - if isinstance(event, h2.events.RequestReceived): - h2_conn.send_headers(1, [(':status', '200')]) - h2_conn.push_stream(1, 2, [ - (':authority', "127.0.0.1:%s" % self.address.port), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/pushed_stream_foo'), - ('foo', 'bar') - ]) - h2_conn.push_stream(1, 4, [ - (':authority', "127.0.0.1:%s" % self.address.port), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/pushed_stream_bar'), - ('foo', 'bar') - ]) - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() + @classmethod + def get_proxy_config(cls): + cls.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return dict( + no_upstream_cert = False, + cadir = cls.cadir, + authenticator = None, + ) - h2_conn.send_headers(2, [(':status', '202')]) - h2_conn.send_headers(4, [(':status', '204')]) - h2_conn.send_data(1, b'regular_stream') - h2_conn.send_data(2, b'pushed_stream_foo') - h2_conn.send_data(4, b'pushed_stream_bar') - h2_conn.end_stream(1) - h2_conn.end_stream(2) - h2_conn.end_stream(4) - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() - elif isinstance(event, h2.events.ConnectionTerminated): - return + def setup(self): + self.master.clear_log() + self.master.state.clear() + self.server.server.handle_server_event = self.handle_server_event - -@requires_alpn -class TestHttp2(tservers.ProxTestBase): def _setup_connection(self): self.config.http2 = True @@ -123,7 +109,7 @@ class TestHttp2(tservers.ProxTestBase): client.wfile.write( b"CONNECT localhost:%d HTTP/1.1\r\n" b"Host: localhost:%d\r\n" - b"\r\n" % (self.server.port, self.server.port) + b"\r\n" % (self.server.server.address.port, self.server.server.address.port) ) client.wfile.flush() @@ -149,14 +135,40 @@ class TestHttp2(tservers.ProxTestBase): wfile.write(h2_conn.data_to_send()) wfile.flush() - def test_simple(self): - self.server = SimpleHttp2Server() - self.server.setup_class() +@requires_alpn +class TestSimple(_Http2TestBase, _Http2ServerBase): + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(1, [ + (':status', '200'), + ('foo', 'bar'), + ]) + h2_conn.send_data(1, b'foobar') + h2_conn.end_stream(1) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_simple(self): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.port), + (':authority', "127.0.0.1:%s" % self.server.server.address.port), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -176,21 +188,69 @@ class TestHttp2(tservers.ProxTestBase): client.wfile.write(h2_conn.data_to_send()) client.wfile.flush() - self.server.teardown_class() - assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response.status_code == 200 assert self.master.state.flows[0].response.headers['foo'] == 'bar' assert self.master.state.flows[0].response.body == b'foobar' - def test_pushed_streams(self): - self.server = PushHttp2Server() - self.server.setup_class() +@requires_alpn +class TestPushPromise(_Http2TestBase, _Http2ServerBase): + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + if event.stream_id != 1: + # ignore requests initiated by push promises + return True + + h2_conn.send_headers(1, [(':status', '200')]) + h2_conn.push_stream(1, 2, [ + (':authority', "127.0.0.1:%s" % self.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_foo'), + ('foo', 'bar') + ]) + h2_conn.push_stream(1, 4, [ + (':authority', "127.0.0.1:%s" % self.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_bar'), + ('foo', 'bar') + ]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_headers(2, [(':status', '202')]) + h2_conn.send_headers(4, [(':status', '204')]) + h2_conn.send_data(1, b'regular_stream') + h2_conn.send_data(2, b'pushed_stream_foo') + h2_conn.send_data(4, b'pushed_stream_bar') + h2_conn.end_stream(1) + h2_conn.end_stream(2) + h2_conn.end_stream(4) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_push_promise(self): client, h2_conn = self._setup_connection() - self._send_request(client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.port), + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -198,6 +258,7 @@ class TestHttp2(tservers.ProxTestBase): ]) ended_streams = 0 + pushed_streams = 0 while ended_streams != 3: try: events = h2_conn.receive_data(utils.http2_read_frame(client.rfile)) @@ -209,10 +270,13 @@ class TestHttp2(tservers.ProxTestBase): for event in events: if isinstance(event, h2.events.StreamEnded): ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 - self.server.teardown_class() + assert pushed_streams == 2 - assert len(self.master.state.flows) == 3 - assert self.master.state.flows[0].response.body == b'regular_stream' - assert self.master.state.flows[1].response.body == b'pushed_stream_foo' - assert self.master.state.flows[2].response.body == b'pushed_stream_bar' + bodies = [flow.response.body for flow in self.master.state.flows] + assert len(bodies) == 3 + assert b'regular_stream' in bodies + assert b'pushed_stream_foo' in bodies + assert b'pushed_stream_bar' in bodies