fix http2 race condition

This commit is contained in:
Thomas Kriechbaumer 2016-02-02 23:54:35 +01:00
parent ca5cc34d0b
commit cf8c063773
2 changed files with 107 additions and 26 deletions

View File

@ -167,7 +167,8 @@ class Http2Layer(Layer):
new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()])
other_conn.h2.safe_update_settings(new_settings)
elif isinstance(event, ConnectionTerminated):
other_conn.h2.safe_close_connection(event.error_code)
# Do not immediately terminate the other connection.
# Some streams might be still sending data to the client.
return False
elif isinstance(event, PushedStreamReceived):
# pushed stream ids should be uniq and not dependent on race conditions
@ -183,7 +184,7 @@ class Http2Layer(Layer):
self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
self.streams[event.pushed_stream_id].timestamp_end = time.time()
self.streams[event.pushed_stream_id].data_finished.set()
self.streams[event.pushed_stream_id].request_data_finished.set()
self.streams[event.pushed_stream_id].start()
elif isinstance(event, TrailersReceived):
raise NotImplementedError()
@ -240,18 +241,50 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
self.server_stream_id = None
self.request_headers = request_headers
self.response_headers = None
self.data_queue = Queue.Queue()
self.queued_data_length = 0
self.pushed = False
self.request_data_queue = Queue.Queue()
self.request_queued_data_length = 0
self.request_data_finished = threading.Event()
self.response_arrived = threading.Event()
self.data_finished = threading.Event()
self.response_data_queue = Queue.Queue()
self.response_queued_data_length = 0
self.response_data_finished = threading.Event()
@property
def data_queue(self):
if self.response_arrived.is_set():
return self.response_data_queue
else:
return self.request_data_queue
@property
def queued_data_length(self):
if self.response_arrived.is_set():
return self.response_queued_data_length
else:
return self.request_queued_data_length
@property
def data_finished(self):
if self.response_arrived.is_set():
return self.response_data_finished
else:
return self.request_data_finished
@queued_data_length.setter
def queued_data_length(self, v):
if self.response_arrived.is_set():
return self.response_queued_data_length
else:
return self.request_queued_data_length
def is_zombie(self):
return self.zombie is not None
def read_request(self):
self.data_finished.wait()
self.data_finished.clear()
self.request_data_finished.wait()
authority = self.request_headers.get(':authority', '')
method = self.request_headers.get(':method', 'GET')
@ -279,8 +312,8 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
port = int(port)
data = []
while self.data_queue.qsize() > 0:
data.append(self.data_queue.get())
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
data = b"".join(data)
return HTTPRequest(
@ -298,9 +331,15 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
)
def send_request(self, message):
if self.zombie:
if self.pushed:
# nothing to do here
return
with self.server_conn.h2.lock:
# We must not assign a stream id if we are already a zombie.
if self.zombie:
return
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
@ -333,12 +372,12 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
def read_response_body(self, request, response):
while True:
try:
yield self.data_queue.get(timeout=1)
yield self.response_data_queue.get(timeout=1)
except Queue.Empty:
pass
if self.data_finished.is_set():
while self.data_queue.qsize() > 0:
yield self.data_queue.get()
if self.response_data_finished.is_set():
while self.response_data_queue.qsize() > 0:
yield self.response_data_queue.get()
return
if self.zombie:
return

View File

@ -5,6 +5,7 @@ import pytest
import traceback
import os
import tempfile
import sys
from libmproxy.proxy.config import ProxyConfig
from libmproxy.proxy.server import ProxyServer
@ -47,9 +48,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
while True:
done = False
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(self.rfile)))
raw = b''.join(http2_read_raw_frame(self.rfile))
events = h2_conn.receive_data(raw)
except:
break
self.wfile.write(h2_conn.data_to_send())
@ -58,10 +61,12 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
for event in events:
try:
if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile):
done = True
break
except Exception as e:
print(repr(e))
print(traceback.format_exc())
done = True
break
def handle_server_event(self, h2_conn, rfile, wfile):
@ -182,7 +187,10 @@ class TestSimple(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
@ -248,7 +256,10 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
@ -303,14 +314,16 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
wfile.write(h2_conn.data_to_send())
wfile.flush()
h2_conn.send_headers(2, [(':status', '202')])
h2_conn.send_headers(4, [(':status', '204')])
h2_conn.send_headers(2, [(':status', '200')])
h2_conn.send_headers(4, [(':status', '200')])
wfile.write(h2_conn.data_to_send())
wfile.flush()
h2_conn.send_data(1, b'regular_stream')
h2_conn.send_data(2, b'pushed_stream_foo')
h2_conn.send_data(4, b'pushed_stream_bar')
wfile.write(h2_conn.data_to_send())
wfile.flush()
h2_conn.end_stream(1)
h2_conn.end_stream(2)
h2_conn.end_stream(4)
@ -330,11 +343,14 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
('foo', 'bar')
])
done = False
ended_streams = 0
pushed_streams = 0
while ended_streams != 3:
responses = 0
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
raw = b''.join(http2_read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except:
break
client.wfile.write(h2_conn.data_to_send())
@ -345,7 +361,19 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
ended_streams += 1
elif isinstance(event, h2.events.PushedStreamReceived):
pushed_streams += 1
elif isinstance(event, h2.events.ResponseReceived):
responses += 1
if isinstance(event, h2.events.ConnectionTerminated):
done = True
if responses == 3 and ended_streams == 3 and pushed_streams == 2:
done = True
h2_conn.close_connection()
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
assert ended_streams == 3
assert pushed_streams == 2
bodies = [flow.response.body for flow in self.master.state.flows]
@ -365,8 +393,11 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
('foo', 'bar')
])
streams = 0
while streams != 3:
done = False
ended_streams = 0
pushed_streams = 0
responses = 0
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
except:
@ -376,12 +407,23 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
for event in events:
if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1:
streams += 1
ended_streams += 1
elif isinstance(event, h2.events.PushedStreamReceived):
streams += 1
pushed_streams += 1
h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8)
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
elif isinstance(event, h2.events.ResponseReceived):
responses += 1
if isinstance(event, h2.events.ConnectionTerminated):
done = True
if responses >= 1 and ended_streams >= 1 and pushed_streams == 2:
done = True
h2_conn.close_connection()
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
bodies = [flow.response.body for flow in self.master.state.flows if flow.response]
assert len(bodies) >= 1