From cf8c063773b70ad37ab0a2125f5ed03c35e17336 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 2 Feb 2016 23:54:35 +0100 Subject: [PATCH] fix http2 race condition --- libmproxy/protocol/http2.py | 67 +++++++++++++++++++++++++++++-------- test/test_protocol_http2.py | 66 +++++++++++++++++++++++++++++------- 2 files changed, 107 insertions(+), 26 deletions(-) diff --git a/libmproxy/protocol/http2.py b/libmproxy/protocol/http2.py index e617f77c0..de0688364 100644 --- a/libmproxy/protocol/http2.py +++ b/libmproxy/protocol/http2.py @@ -167,7 +167,8 @@ class Http2Layer(Layer): new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()]) other_conn.h2.safe_update_settings(new_settings) elif isinstance(event, ConnectionTerminated): - other_conn.h2.safe_close_connection(event.error_code) + # Do not immediately terminate the other connection. + # Some streams might be still sending data to the client. return False elif isinstance(event, PushedStreamReceived): # pushed stream ids should be uniq and not dependent on race conditions @@ -183,7 +184,7 @@ class Http2Layer(Layer): self.streams[event.pushed_stream_id].pushed = True self.streams[event.pushed_stream_id].parent_stream_id = parent_eid self.streams[event.pushed_stream_id].timestamp_end = time.time() - self.streams[event.pushed_stream_id].data_finished.set() + self.streams[event.pushed_stream_id].request_data_finished.set() self.streams[event.pushed_stream_id].start() elif isinstance(event, TrailersReceived): raise NotImplementedError() @@ -240,18 +241,50 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): self.server_stream_id = None self.request_headers = request_headers self.response_headers = None - self.data_queue = Queue.Queue() - self.queued_data_length = 0 + self.pushed = False + + self.request_data_queue = Queue.Queue() + self.request_queued_data_length = 0 + self.request_data_finished = threading.Event() self.response_arrived = threading.Event() - self.data_finished = threading.Event() + self.response_data_queue = Queue.Queue() + self.response_queued_data_length = 0 + self.response_data_finished = threading.Event() + + @property + def data_queue(self): + if self.response_arrived.is_set(): + return self.response_data_queue + else: + return self.request_data_queue + + @property + def queued_data_length(self): + if self.response_arrived.is_set(): + return self.response_queued_data_length + else: + return self.request_queued_data_length + + @property + def data_finished(self): + if self.response_arrived.is_set(): + return self.response_data_finished + else: + return self.request_data_finished + + @queued_data_length.setter + def queued_data_length(self, v): + if self.response_arrived.is_set(): + return self.response_queued_data_length + else: + return self.request_queued_data_length def is_zombie(self): return self.zombie is not None def read_request(self): - self.data_finished.wait() - self.data_finished.clear() + self.request_data_finished.wait() authority = self.request_headers.get(':authority', '') method = self.request_headers.get(':method', 'GET') @@ -279,8 +312,8 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): port = int(port) data = [] - while self.data_queue.qsize() > 0: - data.append(self.data_queue.get()) + while self.request_data_queue.qsize() > 0: + data.append(self.request_data_queue.get()) data = b"".join(data) return HTTPRequest( @@ -298,9 +331,15 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): ) def send_request(self, message): - if self.zombie: + if self.pushed: + # nothing to do here return + with self.server_conn.h2.lock: + # We must not assign a stream id if we are already a zombie. + if self.zombie: + return + self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id @@ -333,12 +372,12 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): def read_response_body(self, request, response): while True: try: - yield self.data_queue.get(timeout=1) + yield self.response_data_queue.get(timeout=1) except Queue.Empty: pass - if self.data_finished.is_set(): - while self.data_queue.qsize() > 0: - yield self.data_queue.get() + if self.response_data_finished.is_set(): + while self.response_data_queue.qsize() > 0: + yield self.response_data_queue.get() return if self.zombie: return diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py index cc62f734b..38cfdfc3c 100644 --- a/test/test_protocol_http2.py +++ b/test/test_protocol_http2.py @@ -5,6 +5,7 @@ import pytest import traceback import os import tempfile +import sys from libmproxy.proxy.config import ProxyConfig from libmproxy.proxy.server import ProxyServer @@ -47,9 +48,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): self.wfile.write(h2_conn.data_to_send()) self.wfile.flush() - while True: + done = False + while not done: try: - events = h2_conn.receive_data(b''.join(http2_read_raw_frame(self.rfile))) + raw = b''.join(http2_read_raw_frame(self.rfile)) + events = h2_conn.receive_data(raw) except: break self.wfile.write(h2_conn.data_to_send()) @@ -58,10 +61,12 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): for event in events: try: if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): + done = True break except Exception as e: print(repr(e)) print(traceback.format_exc()) + done = True break def handle_server_event(self, h2_conn, rfile, wfile): @@ -182,7 +187,10 @@ class TestSimple(_Http2TestBase, _Http2ServerBase): done = False while not done: - events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break client.wfile.write(h2_conn.data_to_send()) client.wfile.flush() @@ -248,7 +256,10 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase): done = False while not done: - events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + try: + events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + except: + break client.wfile.write(h2_conn.data_to_send()) client.wfile.flush() @@ -303,14 +314,16 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): wfile.write(h2_conn.data_to_send()) wfile.flush() - h2_conn.send_headers(2, [(':status', '202')]) - h2_conn.send_headers(4, [(':status', '204')]) + h2_conn.send_headers(2, [(':status', '200')]) + h2_conn.send_headers(4, [(':status', '200')]) wfile.write(h2_conn.data_to_send()) wfile.flush() h2_conn.send_data(1, b'regular_stream') h2_conn.send_data(2, b'pushed_stream_foo') h2_conn.send_data(4, b'pushed_stream_bar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() h2_conn.end_stream(1) h2_conn.end_stream(2) h2_conn.end_stream(4) @@ -330,11 +343,14 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): ('foo', 'bar') ]) + done = False ended_streams = 0 pushed_streams = 0 - while ended_streams != 3: + responses = 0 + while not done: try: - events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + raw = b''.join(http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) except: break client.wfile.write(h2_conn.data_to_send()) @@ -345,7 +361,19 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): ended_streams += 1 elif isinstance(event, h2.events.PushedStreamReceived): pushed_streams += 1 + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + if responses == 3 and ended_streams == 3 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert ended_streams == 3 assert pushed_streams == 2 bodies = [flow.response.body for flow in self.master.state.flows] @@ -365,8 +393,11 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): ('foo', 'bar') ]) - streams = 0 - while streams != 3: + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: try: events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) except: @@ -376,12 +407,23 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): for event in events: if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: - streams += 1 + ended_streams += 1 elif isinstance(event, h2.events.PushedStreamReceived): - streams += 1 + pushed_streams += 1 h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) client.wfile.write(h2_conn.data_to_send()) client.wfile.flush() + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses >= 1 and ended_streams >= 1 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() bodies = [flow.response.body for flow in self.master.state.flows if flow.response] assert len(bodies) >= 1