improved zombie detection

This commit is contained in:
Thomas Kriechbaumer 2016-01-16 11:31:43 +01:00
parent 3f44eff143
commit 947f79eb6c

View File

@ -157,7 +157,7 @@ class SafeH2Connection(H2Connection):
with self.lock: with self.lock:
try: try:
self.reset_stream(stream_id, error_code) self.reset_stream(stream_id, error_code)
except StreamClosedError: except h2.exceptions.ProtocolError:
# stream is already closed - good # stream is already closed - good
pass pass
self.conn.send(self.data_to_send()) self.conn.send(self.data_to_send())
@ -172,30 +172,33 @@ class SafeH2Connection(H2Connection):
self.update_settings(new_settings) self.update_settings(new_settings)
self.conn.send(self.data_to_send()) self.conn.send(self.data_to_send())
def safe_send_headers(self, stream_id, headers): def safe_send_headers(self, is_zombie, stream_id, headers):
with self.lock: with self.lock:
if is_zombie(self, stream_id):
return
self.send_headers(stream_id, headers) self.send_headers(stream_id, headers)
self.conn.send(self.data_to_send()) self.conn.send(self.data_to_send())
def safe_send_body(self, stream_id, chunks): def safe_send_body(self, is_zombie, stream_id, chunks):
# TODO: this assumes the MAX_FRAME_SIZE does not change in the middle
# of a transfer - it could though. Then we need to re-chunk everything.
for chunk in chunks: for chunk in chunks:
max_outbound_frame_size = self.max_outbound_frame_size position = 0
for i in xrange(0, len(chunk), max_outbound_frame_size): while position < len(chunk):
frame_chunk = chunk[i:i+max_outbound_frame_size]
self.lock.acquire() self.lock.acquire()
while True: max_outbound_frame_size = self.max_outbound_frame_size
if self.local_flow_control_window(stream_id) < len(frame_chunk): frame_chunk = chunk[position:position+max_outbound_frame_size]
self.lock.release() if self.local_flow_control_window(stream_id) < len(frame_chunk):
time.sleep(0) self.lock.release()
else: time.sleep(0)
break continue
if is_zombie(self, stream_id):
return
self.send_data(stream_id, frame_chunk) self.send_data(stream_id, frame_chunk)
self.conn.send(self.data_to_send()) self.conn.send(self.data_to_send())
self.lock.release() self.lock.release()
position += max_outbound_frame_size
with self.lock: with self.lock:
if is_zombie(self, stream_id):
return
self.end_stream(stream_id) self.end_stream(stream_id)
self.conn.send(self.data_to_send()) self.conn.send(self.data_to_send())
@ -254,45 +257,45 @@ class Http2Layer(Layer):
events = source_conn.h2.receive_data(raw_frame) events = source_conn.h2.receive_data(raw_frame)
source_conn.send(source_conn.h2.data_to_send()) source_conn.send(source_conn.h2.data_to_send())
for event in events: for event in events:
if hasattr(event, 'stream_id'): if hasattr(event, 'stream_id'):
if is_server:
eid = self.server_to_client_stream_ids[event.stream_id]
else:
eid = event.stream_id
if isinstance(event, RequestReceived):
headers = Headers([[str(k), str(v)] for k, v in event.headers])
self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
self.streams[eid].start()
elif isinstance(event, ResponseReceived):
headers = Headers([[str(k), str(v)] for k, v in event.headers])
self.streams[eid].response_headers = headers
self.streams[eid].response_arrived.set()
elif isinstance(event, DataReceived):
self.streams[eid].data_queue.put(event.data)
source_conn.h2.safe_increment_flow_control(event.stream_id, len(event.data))
elif isinstance(event, StreamEnded):
self.streams[eid].data_finished.set()
elif isinstance(event, StreamReset):
self.streams[eid].zombie = time.time()
if eid in self.streams and event.error_code == 0x8:
if is_server: if is_server:
other_stream_id = self.streams[eid].client_stream_id eid = self.server_to_client_stream_ids[event.stream_id]
else: else:
other_stream_id = self.streams[eid].server_stream_id eid = event.stream_id
other_conn.h2.safe_reset_stream(other_stream_id, event.error_code)
elif isinstance(event, RemoteSettingsChanged): if isinstance(event, RequestReceived):
source_conn.h2.safe_acknowledge_settings(event) headers = Headers([[str(k), str(v)] for k, v in event.headers])
new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()]) self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
other_conn.h2.safe_update_settings(new_settings) self.streams[eid].start()
elif isinstance(event, ConnectionTerminated): elif isinstance(event, ResponseReceived):
other_conn.h2.safe_close_connection(event.error_code) headers = Headers([[str(k), str(v)] for k, v in event.headers])
return self.streams[eid].response_headers = headers
elif isinstance(event, TrailersReceived): self.streams[eid].response_arrived.set()
raise NotImplementedError() elif isinstance(event, DataReceived):
elif isinstance(event, PushedStreamReceived): self.streams[eid].data_queue.put(event.data)
raise NotImplementedError() source_conn.h2.safe_increment_flow_control(event.stream_id, len(event.data))
elif isinstance(event, StreamEnded):
self.streams[eid].data_finished.set()
elif isinstance(event, StreamReset):
self.streams[eid].zombie = time.time()
if eid in self.streams and event.error_code == 0x8:
if is_server:
other_stream_id = self.streams[eid].client_stream_id
else:
other_stream_id = self.streams[eid].server_stream_id
other_conn.h2.safe_reset_stream(other_stream_id, event.error_code)
elif isinstance(event, RemoteSettingsChanged):
source_conn.h2.safe_acknowledge_settings(event)
new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()])
other_conn.h2.safe_update_settings(new_settings)
elif isinstance(event, ConnectionTerminated):
other_conn.h2.safe_close_connection(event.error_code)
return
elif isinstance(event, TrailersReceived):
raise NotImplementedError()
elif isinstance(event, PushedStreamReceived):
raise NotImplementedError()
death_time = time.time() - 10 death_time = time.time() - 10
for stream_id in self.streams.keys(): for stream_id in self.streams.keys():
@ -314,6 +317,18 @@ class Http2SingleStreamLayer(_HttpLayer, threading.Thread):
self.response_arrived = threading.Event() self.response_arrived = threading.Event()
self.data_finished = threading.Event() self.data_finished = threading.Event()
def is_zombie(self, h2_conn, stream_id):
if self.zombie:
return True
try:
h2_conn._get_stream_by_id(stream_id)
except Exception as e:
if isinstance(e, h2.exceptions.StreamClosedError):
return true
return False
def read_request(self): def read_request(self):
self.data_finished.wait() self.data_finished.wait()
self.data_finished.clear() self.data_finished.clear()
@ -364,18 +379,17 @@ class Http2SingleStreamLayer(_HttpLayer, threading.Thread):
) )
def send_request(self, message): def send_request(self, message):
if self.zombie:
return
with self.server_conn.h2.lock: with self.server_conn.h2.lock:
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
self.server_conn.h2.safe_send_headers( self.server_conn.h2.safe_send_headers(
self.is_zombie,
self.server_stream_id, self.server_stream_id,
message.headers message.headers
) )
self.server_conn.h2.safe_send_body( self.server_conn.h2.safe_send_body(
self.is_zombie,
self.server_stream_id, self.server_stream_id,
message.body message.body
) )
@ -409,19 +423,15 @@ class Http2SingleStreamLayer(_HttpLayer, threading.Thread):
return return
def send_response_headers(self, response): def send_response_headers(self, response):
if self.zombie:
return
self.client_conn.h2.safe_send_headers( self.client_conn.h2.safe_send_headers(
self.is_zombie,
self.client_stream_id, self.client_stream_id,
response.headers response.headers
) )
def send_response_body(self, _response, chunks): def send_response_body(self, _response, chunks):
if self.zombie:
return
self.client_conn.h2.safe_send_body( self.client_conn.h2.safe_send_body(
self.is_zombie,
self.client_stream_id, self.client_stream_id,
chunks chunks
) )