mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 23:09:44 +00:00
fix http2 race condition
This commit is contained in:
parent
ca5cc34d0b
commit
cf8c063773
@ -167,7 +167,8 @@ class Http2Layer(Layer):
|
||||
new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()])
|
||||
other_conn.h2.safe_update_settings(new_settings)
|
||||
elif isinstance(event, ConnectionTerminated):
|
||||
other_conn.h2.safe_close_connection(event.error_code)
|
||||
# Do not immediately terminate the other connection.
|
||||
# Some streams might be still sending data to the client.
|
||||
return False
|
||||
elif isinstance(event, PushedStreamReceived):
|
||||
# pushed stream ids should be uniq and not dependent on race conditions
|
||||
@ -183,7 +184,7 @@ class Http2Layer(Layer):
|
||||
self.streams[event.pushed_stream_id].pushed = True
|
||||
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
|
||||
self.streams[event.pushed_stream_id].timestamp_end = time.time()
|
||||
self.streams[event.pushed_stream_id].data_finished.set()
|
||||
self.streams[event.pushed_stream_id].request_data_finished.set()
|
||||
self.streams[event.pushed_stream_id].start()
|
||||
elif isinstance(event, TrailersReceived):
|
||||
raise NotImplementedError()
|
||||
@ -240,18 +241,50 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
|
||||
self.server_stream_id = None
|
||||
self.request_headers = request_headers
|
||||
self.response_headers = None
|
||||
self.data_queue = Queue.Queue()
|
||||
self.queued_data_length = 0
|
||||
self.pushed = False
|
||||
|
||||
self.request_data_queue = Queue.Queue()
|
||||
self.request_queued_data_length = 0
|
||||
self.request_data_finished = threading.Event()
|
||||
|
||||
self.response_arrived = threading.Event()
|
||||
self.data_finished = threading.Event()
|
||||
self.response_data_queue = Queue.Queue()
|
||||
self.response_queued_data_length = 0
|
||||
self.response_data_finished = threading.Event()
|
||||
|
||||
@property
|
||||
def data_queue(self):
|
||||
if self.response_arrived.is_set():
|
||||
return self.response_data_queue
|
||||
else:
|
||||
return self.request_data_queue
|
||||
|
||||
@property
|
||||
def queued_data_length(self):
|
||||
if self.response_arrived.is_set():
|
||||
return self.response_queued_data_length
|
||||
else:
|
||||
return self.request_queued_data_length
|
||||
|
||||
@property
|
||||
def data_finished(self):
|
||||
if self.response_arrived.is_set():
|
||||
return self.response_data_finished
|
||||
else:
|
||||
return self.request_data_finished
|
||||
|
||||
@queued_data_length.setter
|
||||
def queued_data_length(self, v):
|
||||
if self.response_arrived.is_set():
|
||||
return self.response_queued_data_length
|
||||
else:
|
||||
return self.request_queued_data_length
|
||||
|
||||
def is_zombie(self):
|
||||
return self.zombie is not None
|
||||
|
||||
def read_request(self):
|
||||
self.data_finished.wait()
|
||||
self.data_finished.clear()
|
||||
self.request_data_finished.wait()
|
||||
|
||||
authority = self.request_headers.get(':authority', '')
|
||||
method = self.request_headers.get(':method', 'GET')
|
||||
@ -279,8 +312,8 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
|
||||
port = int(port)
|
||||
|
||||
data = []
|
||||
while self.data_queue.qsize() > 0:
|
||||
data.append(self.data_queue.get())
|
||||
while self.request_data_queue.qsize() > 0:
|
||||
data.append(self.request_data_queue.get())
|
||||
data = b"".join(data)
|
||||
|
||||
return HTTPRequest(
|
||||
@ -298,9 +331,15 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
|
||||
)
|
||||
|
||||
def send_request(self, message):
|
||||
if self.zombie:
|
||||
if self.pushed:
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
with self.server_conn.h2.lock:
|
||||
# We must not assign a stream id if we are already a zombie.
|
||||
if self.zombie:
|
||||
return
|
||||
|
||||
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
|
||||
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
|
||||
|
||||
@ -333,12 +372,12 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
|
||||
def read_response_body(self, request, response):
|
||||
while True:
|
||||
try:
|
||||
yield self.data_queue.get(timeout=1)
|
||||
yield self.response_data_queue.get(timeout=1)
|
||||
except Queue.Empty:
|
||||
pass
|
||||
if self.data_finished.is_set():
|
||||
while self.data_queue.qsize() > 0:
|
||||
yield self.data_queue.get()
|
||||
if self.response_data_finished.is_set():
|
||||
while self.response_data_queue.qsize() > 0:
|
||||
yield self.response_data_queue.get()
|
||||
return
|
||||
if self.zombie:
|
||||
return
|
||||
|
@ -5,6 +5,7 @@ import pytest
|
||||
import traceback
|
||||
import os
|
||||
import tempfile
|
||||
import sys
|
||||
|
||||
from libmproxy.proxy.config import ProxyConfig
|
||||
from libmproxy.proxy.server import ProxyServer
|
||||
@ -47,9 +48,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
self.wfile.flush()
|
||||
|
||||
while True:
|
||||
done = False
|
||||
while not done:
|
||||
try:
|
||||
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(self.rfile)))
|
||||
raw = b''.join(http2_read_raw_frame(self.rfile))
|
||||
events = h2_conn.receive_data(raw)
|
||||
except:
|
||||
break
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
@ -58,10 +61,12 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
|
||||
for event in events:
|
||||
try:
|
||||
if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile):
|
||||
done = True
|
||||
break
|
||||
except Exception as e:
|
||||
print(repr(e))
|
||||
print(traceback.format_exc())
|
||||
done = True
|
||||
break
|
||||
|
||||
def handle_server_event(self, h2_conn, rfile, wfile):
|
||||
@ -182,7 +187,10 @@ class TestSimple(_Http2TestBase, _Http2ServerBase):
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
|
||||
try:
|
||||
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
|
||||
except:
|
||||
break
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
|
||||
@ -248,7 +256,10 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase):
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
|
||||
try:
|
||||
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
|
||||
except:
|
||||
break
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
|
||||
@ -303,14 +314,16 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
|
||||
wfile.write(h2_conn.data_to_send())
|
||||
wfile.flush()
|
||||
|
||||
h2_conn.send_headers(2, [(':status', '202')])
|
||||
h2_conn.send_headers(4, [(':status', '204')])
|
||||
h2_conn.send_headers(2, [(':status', '200')])
|
||||
h2_conn.send_headers(4, [(':status', '200')])
|
||||
wfile.write(h2_conn.data_to_send())
|
||||
wfile.flush()
|
||||
|
||||
h2_conn.send_data(1, b'regular_stream')
|
||||
h2_conn.send_data(2, b'pushed_stream_foo')
|
||||
h2_conn.send_data(4, b'pushed_stream_bar')
|
||||
wfile.write(h2_conn.data_to_send())
|
||||
wfile.flush()
|
||||
h2_conn.end_stream(1)
|
||||
h2_conn.end_stream(2)
|
||||
h2_conn.end_stream(4)
|
||||
@ -330,11 +343,14 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
|
||||
('foo', 'bar')
|
||||
])
|
||||
|
||||
done = False
|
||||
ended_streams = 0
|
||||
pushed_streams = 0
|
||||
while ended_streams != 3:
|
||||
responses = 0
|
||||
while not done:
|
||||
try:
|
||||
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
|
||||
raw = b''.join(http2_read_raw_frame(client.rfile))
|
||||
events = h2_conn.receive_data(raw)
|
||||
except:
|
||||
break
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
@ -345,7 +361,19 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
|
||||
ended_streams += 1
|
||||
elif isinstance(event, h2.events.PushedStreamReceived):
|
||||
pushed_streams += 1
|
||||
elif isinstance(event, h2.events.ResponseReceived):
|
||||
responses += 1
|
||||
if isinstance(event, h2.events.ConnectionTerminated):
|
||||
done = True
|
||||
|
||||
if responses == 3 and ended_streams == 3 and pushed_streams == 2:
|
||||
done = True
|
||||
|
||||
h2_conn.close_connection()
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
|
||||
assert ended_streams == 3
|
||||
assert pushed_streams == 2
|
||||
|
||||
bodies = [flow.response.body for flow in self.master.state.flows]
|
||||
@ -365,8 +393,11 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
|
||||
('foo', 'bar')
|
||||
])
|
||||
|
||||
streams = 0
|
||||
while streams != 3:
|
||||
done = False
|
||||
ended_streams = 0
|
||||
pushed_streams = 0
|
||||
responses = 0
|
||||
while not done:
|
||||
try:
|
||||
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
|
||||
except:
|
||||
@ -376,12 +407,23 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1:
|
||||
streams += 1
|
||||
ended_streams += 1
|
||||
elif isinstance(event, h2.events.PushedStreamReceived):
|
||||
streams += 1
|
||||
pushed_streams += 1
|
||||
h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8)
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
elif isinstance(event, h2.events.ResponseReceived):
|
||||
responses += 1
|
||||
if isinstance(event, h2.events.ConnectionTerminated):
|
||||
done = True
|
||||
|
||||
if responses >= 1 and ended_streams >= 1 and pushed_streams == 2:
|
||||
done = True
|
||||
|
||||
h2_conn.close_connection()
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
|
||||
bodies = [flow.response.body for flow in self.master.state.flows if flow.response]
|
||||
assert len(bodies) >= 1
|
||||
|
Loading…
Reference in New Issue
Block a user