This commit is contained in:
Thomas Kriechbaumer 2016-06-11 12:10:17 +02:00
parent e0d6434b27
commit a5aa16e03b

View File

@ -56,11 +56,11 @@ class SafeH2Connection(connection.H2Connection):
self.conn.send(self.data_to_send())
def safe_send_headers(self, is_zombie, stream_id, headers):
with self.lock:
if is_zombie(): # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
self.send_headers(stream_id, headers.fields)
self.conn.send(self.data_to_send())
# make sure to have a lock
if is_zombie(): # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
self.send_headers(stream_id, headers.fields)
self.conn.send(self.data_to_send())
def safe_send_body(self, is_zombie, stream_id, chunks):
for chunk in chunks:
@ -77,8 +77,12 @@ class SafeH2Connection(connection.H2Connection):
time.sleep(0.1)
continue
self.send_data(stream_id, frame_chunk)
self.conn.send(self.data_to_send())
self.lock.release()
try:
self.conn.send(self.data_to_send())
except Exception as e:
raise e
finally:
self.lock.release()
position += max_outbound_frame_size
with self.lock:
if is_zombie(): # pragma: no cover
@ -225,6 +229,9 @@ class Http2Layer(base.Layer):
for stream in self.streams.values():
if not stream.zombie:
stream.zombie = time.time()
stream.request_data_finished.set()
stream.response_arrived.set()
stream.data_finished.set()
def __call__(self):
if self.server_conn:
@ -235,31 +242,36 @@ class Http2Layer(base.Layer):
self.client_conn.h2.receive_data(preamble)
self.client_conn.send(self.client_conn.h2.data_to_send())
while True:
r = tcp.ssl_read_select(self.active_conns, 1)
for conn in r:
source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
is_server = (conn == self.server_conn.connection)
try:
while True:
r = tcp.ssl_read_select(self.active_conns, 1)
for conn in r:
source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
is_server = (conn == self.server_conn.connection)
with source_conn.h2.lock:
try:
raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile))
except:
# read frame failed: connection closed
self._kill_all_streams()
return
incoming_events = source_conn.h2.receive_data(raw_frame)
source_conn.send(source_conn.h2.data_to_send())
for event in incoming_events:
if not self._handle_event(event, source_conn, other_conn, is_server):
# connection terminated: GoAway
with source_conn.h2.lock:
try:
raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile))
except:
# read frame failed: connection closed
self._kill_all_streams()
return
self._cleanup_streams()
incoming_events = source_conn.h2.receive_data(raw_frame)
source_conn.send(source_conn.h2.data_to_send())
for event in incoming_events:
if not self._handle_event(event, source_conn, other_conn, is_server):
# connection terminated: GoAway
self._kill_all_streams()
return
self._cleanup_streams()
except Exception as e:
self.log(repr(e), "info")
self.log(traceback.format_exc(), "debug")
self._kill_all_streams()
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
@ -315,6 +327,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def read_request(self):
self.request_data_finished.wait()
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
authority = self.request_headers.get(':authority', '')
method = self.request_headers.get(':method', 'GET')
scheme = self.request_headers.get(':scheme', 'https')
@ -366,31 +381,32 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
raise NotImplementedError()
def send_request(self, message):
if not hasattr(self.server_conn, 'h2'):
raise exceptions.Http2ProtocolException("Zombie Stream")
if self.pushed:
# nothing to do here
return
while True:
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
self.server_conn.h2.lock.acquire()
max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams
if self.server_conn.h2.open_outbound_streams + 1 >= max_streams:
# wait until we get a free slot for a new outgoing stream
self.server_conn.h2.lock.release()
time.sleep(0.1)
else:
break
continue
if self.pushed:
# nothing to do here
self.server_conn.h2.lock.release()
return
# keep the lock
break
with self.server_conn.h2.lock:
# We must not assign a stream id if we are already a zombie.
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
# We must not assign a stream id if we are already a zombie.
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
headers = message.headers.copy()
headers.insert(0, ":path", message.path)
@ -398,12 +414,17 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
headers.insert(0, ":scheme", message.scheme)
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
self.server_conn.h2.safe_send_headers(
self.is_zombie,
self.server_stream_id,
headers,
)
self.server_conn.h2.lock.release()
try:
self.server_conn.h2.safe_send_headers(
self.is_zombie,
self.server_stream_id,
headers,
)
except Exception as e:
raise e
finally:
self.server_conn.h2.lock.release()
self.server_conn.h2.safe_send_body(
self.is_zombie,
@ -416,6 +437,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def read_response_headers(self):
self.response_arrived.wait()
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
headers.clear(":status")
@ -437,6 +461,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except queue.Empty:
pass
if self.response_data_finished.is_set():
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
while self.response_data_queue.qsize() > 0:
yield self.response_data_queue.get()
break
@ -446,11 +472,12 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def send_response_headers(self, response):
headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code))
self.client_conn.h2.safe_send_headers(
self.is_zombie,
self.client_stream_id,
headers
)
with self.client_conn.h2.lock:
self.client_conn.h2.safe_send_headers(
self.is_zombie,
self.client_stream_id,
headers
)
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
@ -484,4 +511,5 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.log(repr(e), "info")
self.log(traceback.format_exc(), "debug")
self.zombie = time.time()
if not self.zombie:
self.zombie = time.time()