From a5aa16e03b00e5715c6b5dfaecab10537776d891 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 11 Jun 2016 12:10:17 +0200 Subject: [PATCH] fix #1240 --- mitmproxy/protocol/http2.py | 136 ++++++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 54 deletions(-) diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 957b8d64a..b9a30c7e1 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -56,11 +56,11 @@ class SafeH2Connection(connection.H2Connection): self.conn.send(self.data_to_send()) def safe_send_headers(self, is_zombie, stream_id, headers): - with self.lock: - if is_zombie(): # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - self.send_headers(stream_id, headers.fields) - self.conn.send(self.data_to_send()) + # make sure to have a lock + if is_zombie(): # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") + self.send_headers(stream_id, headers.fields) + self.conn.send(self.data_to_send()) def safe_send_body(self, is_zombie, stream_id, chunks): for chunk in chunks: @@ -77,8 +77,12 @@ class SafeH2Connection(connection.H2Connection): time.sleep(0.1) continue self.send_data(stream_id, frame_chunk) - self.conn.send(self.data_to_send()) - self.lock.release() + try: + self.conn.send(self.data_to_send()) + except Exception as e: + raise e + finally: + self.lock.release() position += max_outbound_frame_size with self.lock: if is_zombie(): # pragma: no cover @@ -225,6 +229,9 @@ class Http2Layer(base.Layer): for stream in self.streams.values(): if not stream.zombie: stream.zombie = time.time() + stream.request_data_finished.set() + stream.response_arrived.set() + stream.data_finished.set() def __call__(self): if self.server_conn: @@ -235,31 +242,36 @@ class Http2Layer(base.Layer): self.client_conn.h2.receive_data(preamble) self.client_conn.send(self.client_conn.h2.data_to_send()) - while True: - r = tcp.ssl_read_select(self.active_conns, 1) - for conn in r: - source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn - other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn - is_server = (conn == self.server_conn.connection) + try: + while True: + r = tcp.ssl_read_select(self.active_conns, 1) + for conn in r: + source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn + other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + is_server = (conn == self.server_conn.connection) - with source_conn.h2.lock: - try: - raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile)) - except: - # read frame failed: connection closed - self._kill_all_streams() - return - - incoming_events = source_conn.h2.receive_data(raw_frame) - source_conn.send(source_conn.h2.data_to_send()) - - for event in incoming_events: - if not self._handle_event(event, source_conn, other_conn, is_server): - # connection terminated: GoAway + with source_conn.h2.lock: + try: + raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile)) + except: + # read frame failed: connection closed self._kill_all_streams() return - self._cleanup_streams() + incoming_events = source_conn.h2.receive_data(raw_frame) + source_conn.send(source_conn.h2.data_to_send()) + + for event in incoming_events: + if not self._handle_event(event, source_conn, other_conn, is_server): + # connection terminated: GoAway + self._kill_all_streams() + return + + self._cleanup_streams() + except Exception as e: + self.log(repr(e), "info") + self.log(traceback.format_exc(), "debug") + self._kill_all_streams() class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread): @@ -315,6 +327,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def read_request(self): self.request_data_finished.wait() + if self.zombie: # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") + authority = self.request_headers.get(':authority', '') method = self.request_headers.get(':method', 'GET') scheme = self.request_headers.get(':scheme', 'https') @@ -366,31 +381,32 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) raise NotImplementedError() def send_request(self, message): - if not hasattr(self.server_conn, 'h2'): - raise exceptions.Http2ProtocolException("Zombie Stream") + if self.pushed: + # nothing to do here + return while True: + if self.zombie: # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") + self.server_conn.h2.lock.acquire() + max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams if self.server_conn.h2.open_outbound_streams + 1 >= max_streams: # wait until we get a free slot for a new outgoing stream self.server_conn.h2.lock.release() time.sleep(0.1) - else: - break + continue - if self.pushed: - # nothing to do here - self.server_conn.h2.lock.release() - return + # keep the lock + break - with self.server_conn.h2.lock: - # We must not assign a stream id if we are already a zombie. - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + # We must not assign a stream id if we are already a zombie. + if self.zombie: # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") - self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() - self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id + self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() + self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id headers = message.headers.copy() headers.insert(0, ":path", message.path) @@ -398,12 +414,17 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) headers.insert(0, ":scheme", message.scheme) self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id - self.server_conn.h2.safe_send_headers( - self.is_zombie, - self.server_stream_id, - headers, - ) - self.server_conn.h2.lock.release() + + try: + self.server_conn.h2.safe_send_headers( + self.is_zombie, + self.server_stream_id, + headers, + ) + except Exception as e: + raise e + finally: + self.server_conn.h2.lock.release() self.server_conn.h2.safe_send_body( self.is_zombie, @@ -416,6 +437,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def read_response_headers(self): self.response_arrived.wait() + if self.zombie: # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") + status_code = int(self.response_headers.get(':status', 502)) headers = self.response_headers.copy() headers.clear(":status") @@ -437,6 +461,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except queue.Empty: pass if self.response_data_finished.is_set(): + if self.zombie: # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") while self.response_data_queue.qsize() > 0: yield self.response_data_queue.get() break @@ -446,11 +472,12 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def send_response_headers(self, response): headers = response.headers.copy() headers.insert(0, ":status", str(response.status_code)) - self.client_conn.h2.safe_send_headers( - self.is_zombie, - self.client_stream_id, - headers - ) + with self.client_conn.h2.lock: + self.client_conn.h2.safe_send_headers( + self.is_zombie, + self.client_stream_id, + headers + ) if self.zombie: # pragma: no cover raise exceptions.Http2ProtocolException("Zombie Stream") @@ -484,4 +511,5 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.log(repr(e), "info") self.log(traceback.format_exc(), "debug") - self.zombie = time.time() + if not self.zombie: + self.zombie = time.time()