request streaming for HTTP/2

This commit is contained in:
Ujjwal Verma 2017-06-05 01:52:36 +05:30 committed by Thomas Kriechbaumer
parent 47c9604aed
commit d4f35d7a4a
3 changed files with 155 additions and 13 deletions

View File

@ -333,6 +333,8 @@ class HttpLayer(base.Layer):
if f.request.stream:
self.send_request_headers(f.request)
chunks = self.read_request_body(f.request)
if callable(f.request.stream):
chunks = f.request.stream(chunks)
self.send_request_body(f.request, chunks)
else:
self.send_request(f.request)

View File

@ -487,14 +487,23 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream
def read_request_body(self, request):
self.request_data_finished.wait()
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
return data
if not request.stream:
self.request_data_finished.wait()
while True:
try:
yield self.request_data_queue.get(timeout=0.1)
except queue.Empty: # pragma: no cover
pass
if self.request_data_finished.is_set():
self.raise_zombie()
while self.request_data_queue.qsize() > 0:
yield self.request_data_queue.get()
break
self.raise_zombie()
@detect_zombie_stream
def send_request(self, message):
def send_request_headers(self, request):
if self.pushed:
# nothing to do here
return
@ -519,10 +528,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.server_stream_id = self.connections[self.server_conn].get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
headers = message.headers.copy()
headers.insert(0, ":path", message.path)
headers.insert(0, ":method", message.method)
headers.insert(0, ":scheme", message.scheme)
headers = request.headers.copy()
headers.insert(0, ":path", request.path)
headers.insert(0, ":method", request.method)
headers.insert(0, ":scheme", request.scheme)
priority_exclusive = None
priority_depends_on = None
@ -553,13 +562,24 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.raise_zombie()
self.connections[self.server_conn].lock.release()
@detect_zombie_stream
def send_request_body(self, request, chunks):
if self.pushed:
# nothing to do here
return
if not self.no_body:
self.connections[self.server_conn].safe_send_body(
self.raise_zombie,
self.server_stream_id,
[message.content]
chunks
)
@detect_zombie_stream
def send_request(self, message):
self.send_request_headers(message)
self.send_request_body(message, [message.content])
@detect_zombie_stream
def read_response_headers(self):
self.response_arrived.wait()

View File

@ -14,6 +14,7 @@ import mitmproxy.net
from ...net import tservers as net_tservers
from mitmproxy import exceptions
from mitmproxy.net.http import http1, http2
from pathod.language import generators
from ... import tservers
from ....conftest import requires_alpn
@ -166,7 +167,8 @@ class _Http2TestBase:
end_stream=None,
priority_exclusive=None,
priority_depends_on=None,
priority_weight=None):
priority_weight=None,
streaming=False):
if headers is None:
headers = []
if end_stream is None:
@ -182,7 +184,8 @@ class _Http2TestBase:
)
if body:
h2_conn.send_data(stream_id, body)
h2_conn.end_stream(stream_id)
if not streaming:
h2_conn.end_stream(stream_id)
wfile.write(h2_conn.data_to_send())
wfile.flush()
@ -862,3 +865,120 @@ class TestConnectionTerminated(_Http2Test):
assert connection_terminated_event.error_code == 5
assert connection_terminated_event.last_stream_id == 42
assert connection_terminated_event.additional_data == b'foobar'
@requires_alpn
class TestRequestStreaming(_Http2Test):
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.DataReceived):
data = event.data
assert data
h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=data)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
@pytest.mark.parametrize('streaming', [True, False])
def test_request_streaming(self, streaming):
class Stream:
def requestheaders(self, f):
f.request.stream = streaming
self.master.addons.add(Stream())
h2_conn = self.setup_connection()
body = generators.RandomGenerator("bytes", 100)[:]
self._send_request(
self.client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
],
body=body,
streaming=True
)
done = False
connection_terminated_event = None
self.client.rfile.o.settimeout(2)
while not done:
try:
raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
for event in events:
if isinstance(event, h2.events.ConnectionTerminated):
connection_terminated_event = event
done = True
except:
break
if streaming:
assert connection_terminated_event.additional_data == body
else:
assert connection_terminated_event is None
@requires_alpn
class TestResponseStreaming(_Http2Test):
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
data = generators.RandomGenerator("bytes", 100)[:]
h2_conn.send_headers(event.stream_id, [
(':status', '200'),
('content-length', '100')
])
h2_conn.send_data(event.stream_id, data)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
@pytest.mark.parametrize('streaming', [True, False])
def test_response_streaming(self, streaming):
class Stream:
def responseheaders(self, f):
f.response.stream = streaming
self.master.addons.add(Stream())
h2_conn = self.setup_connection()
self._send_request(
self.client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
]
)
done = False
self.client.rfile.o.settimeout(2)
data = None
while not done:
try:
raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
for event in events:
if isinstance(event, h2.events.DataReceived):
data = event.data
done = True
except:
break
if streaming:
assert data
else:
assert data is None