mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 23:09:44 +00:00
test PushPromise support
This commit is contained in:
parent
97c2530f90
commit
41f4197a0d
@ -17,7 +17,7 @@ from .base import Layer
|
||||
from .http import _HttpTransmissionLayer, HttpLayer
|
||||
from .. import utils
|
||||
from ..models import HTTPRequest, HTTPResponse
|
||||
|
||||
from ..exceptions import HttpProtocolException, ProtocolException
|
||||
|
||||
class SafeH2Connection(H2Connection):
|
||||
def __init__(self, conn, *args, **kwargs):
|
||||
@ -207,7 +207,14 @@ class Http2Layer(Layer):
|
||||
is_server = (conn == self.server_conn.connection)
|
||||
|
||||
with source_conn.h2.lock:
|
||||
events = source_conn.h2.receive_data(utils.http2_read_frame(source_conn.rfile))
|
||||
try:
|
||||
raw_frame = utils.http2_read_frame(source_conn.rfile)
|
||||
except:
|
||||
for stream in self.streams.values():
|
||||
stream.zombie = time.time()
|
||||
return
|
||||
|
||||
events = source_conn.h2.receive_data(raw_frame)
|
||||
source_conn.send(source_conn.h2.data_to_send())
|
||||
|
||||
for event in events:
|
||||
|
@ -28,9 +28,7 @@ requires_alpn = pytest.mark.skipif(
|
||||
|
||||
|
||||
class SimpleHttp2Server(netlib_tservers.ServerTestBase):
|
||||
ssl = dict(
|
||||
alpn_select=b'h2',
|
||||
)
|
||||
ssl = dict(alpn_select=b'h2')
|
||||
|
||||
class handler(netlib.tcp.BaseHandler):
|
||||
def handle(self):
|
||||
@ -61,6 +59,59 @@ class SimpleHttp2Server(netlib_tservers.ServerTestBase):
|
||||
return
|
||||
|
||||
|
||||
class PushHttp2Server(netlib_tservers.ServerTestBase):
|
||||
ssl = dict(alpn_select=b'h2')
|
||||
|
||||
class handler(netlib.tcp.BaseHandler):
|
||||
def handle(self):
|
||||
h2_conn = h2.connection.H2Connection(client_side=False)
|
||||
|
||||
preamble = self.rfile.read(24)
|
||||
h2_conn.initiate_connection()
|
||||
h2_conn.receive_data(preamble)
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
self.wfile.flush()
|
||||
|
||||
while True:
|
||||
events = h2_conn.receive_data(utils.http2_read_frame(self.rfile))
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
self.wfile.flush()
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
h2_conn.send_headers(1, [(':status', '200')])
|
||||
h2_conn.push_stream(1, 2, [
|
||||
(':authority', "127.0.0.1:%s" % self.address.port),
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/pushed_stream_foo'),
|
||||
('foo', 'bar')
|
||||
])
|
||||
h2_conn.push_stream(1, 4, [
|
||||
(':authority', "127.0.0.1:%s" % self.address.port),
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/pushed_stream_bar'),
|
||||
('foo', 'bar')
|
||||
])
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
self.wfile.flush()
|
||||
|
||||
h2_conn.send_headers(2, [(':status', '202')])
|
||||
h2_conn.send_headers(4, [(':status', '204')])
|
||||
h2_conn.send_data(1, b'regular_stream')
|
||||
h2_conn.send_data(2, b'pushed_stream_foo')
|
||||
h2_conn.send_data(4, b'pushed_stream_bar')
|
||||
h2_conn.end_stream(1)
|
||||
h2_conn.end_stream(2)
|
||||
h2_conn.end_stream(4)
|
||||
self.wfile.write(h2_conn.data_to_send())
|
||||
self.wfile.flush()
|
||||
print("HERE")
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
return
|
||||
|
||||
|
||||
@requires_alpn
|
||||
class TestHttp2(tservers.ProxTestBase):
|
||||
def _setup_connection(self):
|
||||
@ -132,3 +183,37 @@ class TestHttp2(tservers.ProxTestBase):
|
||||
assert self.master.state.flows[0].response.status_code == 200
|
||||
assert self.master.state.flows[0].response.headers['foo'] == 'bar'
|
||||
assert self.master.state.flows[0].response.body == b'foobar'
|
||||
|
||||
def test_pushed_streams(self):
|
||||
self.server = PushHttp2Server()
|
||||
self.server.setup_class()
|
||||
|
||||
client, h2_conn = self._setup_connection()
|
||||
|
||||
self._send_request(client.wfile, h2_conn, headers=[
|
||||
(':authority', "127.0.0.1:%s" % self.server.port),
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/'),
|
||||
('foo', 'bar')
|
||||
])
|
||||
|
||||
ended_streams = 0
|
||||
while ended_streams != 3:
|
||||
try:
|
||||
events = h2_conn.receive_data(utils.http2_read_frame(client.rfile))
|
||||
except:
|
||||
break
|
||||
client.wfile.write(h2_conn.data_to_send())
|
||||
client.wfile.flush()
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.StreamEnded):
|
||||
ended_streams += 1
|
||||
|
||||
self.server.teardown_class()
|
||||
|
||||
assert len(self.master.state.flows) == 3
|
||||
assert self.master.state.flows[0].response.body == b'regular_stream'
|
||||
assert self.master.state.flows[1].response.body == b'pushed_stream_foo'
|
||||
assert self.master.state.flows[2].response.body == b'pushed_stream_bar'
|
||||
|
Loading…
Reference in New Issue
Block a user