mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-29 19:08:44 +00:00
unify HTTP trailers APIs
This commit is contained in:
parent
d589f13a1d
commit
ebb061796c
@ -26,6 +26,7 @@ class HTTPRequest(http.Request):
|
|||||||
http_version,
|
http_version,
|
||||||
headers,
|
headers,
|
||||||
content,
|
content,
|
||||||
|
trailers=None,
|
||||||
timestamp_start=None,
|
timestamp_start=None,
|
||||||
timestamp_end=None,
|
timestamp_end=None,
|
||||||
is_replay=False,
|
is_replay=False,
|
||||||
@ -41,6 +42,7 @@ class HTTPRequest(http.Request):
|
|||||||
http_version,
|
http_version,
|
||||||
headers,
|
headers,
|
||||||
content,
|
content,
|
||||||
|
trailers,
|
||||||
timestamp_start,
|
timestamp_start,
|
||||||
timestamp_end,
|
timestamp_end,
|
||||||
)
|
)
|
||||||
@ -73,6 +75,7 @@ class HTTPRequest(http.Request):
|
|||||||
http_version=request.data.http_version,
|
http_version=request.data.http_version,
|
||||||
headers=request.data.headers,
|
headers=request.data.headers,
|
||||||
content=request.data.content,
|
content=request.data.content,
|
||||||
|
trailers=request.data.trailers,
|
||||||
timestamp_start=request.data.timestamp_start,
|
timestamp_start=request.data.timestamp_start,
|
||||||
timestamp_end=request.data.timestamp_end,
|
timestamp_end=request.data.timestamp_end,
|
||||||
)
|
)
|
||||||
|
@ -174,6 +174,7 @@ def convert_6_7(data):
|
|||||||
|
|
||||||
def convert_7_8(data):
|
def convert_7_8(data):
|
||||||
data["version"] = 8
|
data["version"] = 8
|
||||||
|
data["request"]["trailers"] = None
|
||||||
data["response"]["trailers"] = None
|
data["response"]["trailers"] = None
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ def read_request_head(rfile):
|
|||||||
timestamp_start = rfile.first_byte_timestamp
|
timestamp_start = rfile.first_byte_timestamp
|
||||||
|
|
||||||
return request.Request(
|
return request.Request(
|
||||||
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
|
form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -134,6 +134,20 @@ class Message(serializable.Serializable):
|
|||||||
|
|
||||||
content = property(get_content, set_content)
|
content = property(get_content, set_content)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trailers(self):
|
||||||
|
"""
|
||||||
|
Message trailers object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mitmproxy.net.http.Headers
|
||||||
|
"""
|
||||||
|
return self.data.trailers
|
||||||
|
|
||||||
|
@trailers.setter
|
||||||
|
def trailers(self, h):
|
||||||
|
self.data.trailers = h
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def http_version(self):
|
def http_version(self):
|
||||||
"""
|
"""
|
||||||
|
@ -29,6 +29,7 @@ class RequestData(message.MessageData):
|
|||||||
http_version,
|
http_version,
|
||||||
headers=(),
|
headers=(),
|
||||||
content=None,
|
content=None,
|
||||||
|
trailers=None,
|
||||||
timestamp_start=None,
|
timestamp_start=None,
|
||||||
timestamp_end=None
|
timestamp_end=None
|
||||||
):
|
):
|
||||||
@ -46,6 +47,8 @@ class RequestData(message.MessageData):
|
|||||||
headers = nheaders.Headers(headers)
|
headers = nheaders.Headers(headers)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||||
|
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||||
|
trailers = nheaders.Headers(trailers)
|
||||||
|
|
||||||
self.first_line_format = first_line_format
|
self.first_line_format = first_line_format
|
||||||
self.method = method
|
self.method = method
|
||||||
@ -56,6 +59,7 @@ class RequestData(message.MessageData):
|
|||||||
self.http_version = http_version
|
self.http_version = http_version
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.content = content
|
self.content = content
|
||||||
|
self.trailers = trailers
|
||||||
self.timestamp_start = timestamp_start
|
self.timestamp_start = timestamp_start
|
||||||
self.timestamp_end = timestamp_end
|
self.timestamp_end = timestamp_end
|
||||||
|
|
||||||
|
@ -34,6 +34,8 @@ class ResponseData(message.MessageData):
|
|||||||
headers = nheaders.Headers(headers)
|
headers = nheaders.Headers(headers)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||||
|
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||||
|
trailers = nheaders.Headers(trailers)
|
||||||
|
|
||||||
self.http_version = http_version
|
self.http_version = http_version
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
|
@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
def read_request_body(self, request):
|
def read_request_body(self, request):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def read_request_trailers(self, request):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def send_request(self, request):
|
def send_request(self, request):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
yield "this is a generator" # pragma: no cover
|
yield "this is a generator" # pragma: no cover
|
||||||
|
|
||||||
|
def read_response_trailers(self, request, response):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def read_response(self, request):
|
def read_response(self, request):
|
||||||
response = self.read_response_headers()
|
response = self.read_response_headers()
|
||||||
response.data.content = b"".join(
|
response.data.content = b"".join(
|
||||||
self.read_response_body(request, response)
|
self.read_response_body(request, response)
|
||||||
)
|
)
|
||||||
|
response.data.trailers = self.read_response_trailers(request, response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def send_response(self, response):
|
def send_response(self, response):
|
||||||
@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
raise exceptions.HttpException("Cannot assemble flow with missing content")
|
raise exceptions.HttpException("Cannot assemble flow with missing content")
|
||||||
self.send_response_headers(response)
|
self.send_response_headers(response)
|
||||||
self.send_response_body(response, [response.data.content])
|
self.send_response_body(response, [response.data.content])
|
||||||
|
self.send_response_trailers(response)
|
||||||
|
|
||||||
def send_response_headers(self, response):
|
def send_response_headers(self, response):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
def send_response_body(self, response, chunks):
|
def send_response_body(self, response, chunks):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def send_response_trailers(self, response, chunks):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def check_close_connection(self, f):
|
def check_close_connection(self, f):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -255,6 +266,7 @@ class HttpLayer(base.Layer):
|
|||||||
f.request.data.content = b"".join(
|
f.request.data.content = b"".join(
|
||||||
self.read_request_body(f.request)
|
self.read_request_body(f.request)
|
||||||
)
|
)
|
||||||
|
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||||
f.request.timestamp_end = time.time()
|
f.request.timestamp_end = time.time()
|
||||||
self.channel.ask("http_connect", f)
|
self.channel.ask("http_connect", f)
|
||||||
|
|
||||||
@ -282,6 +294,9 @@ class HttpLayer(base.Layer):
|
|||||||
f.request.data.content = None
|
f.request.data.content = None
|
||||||
else:
|
else:
|
||||||
f.request.data.content = b"".join(self.read_request_body(request))
|
f.request.data.content = b"".join(self.read_request_body(request))
|
||||||
|
|
||||||
|
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||||
|
|
||||||
request.timestamp_end = time.time()
|
request.timestamp_end = time.time()
|
||||||
except exceptions.HttpException as e:
|
except exceptions.HttpException as e:
|
||||||
# We optimistically guess there might be an HTTP client on the
|
# We optimistically guess there might be an HTTP client on the
|
||||||
@ -348,6 +363,8 @@ class HttpLayer(base.Layer):
|
|||||||
else:
|
else:
|
||||||
self.send_request_body(f.request, [f.request.data.content])
|
self.send_request_body(f.request, [f.request.data.content])
|
||||||
|
|
||||||
|
self.send_request_trailers(f.request)
|
||||||
|
|
||||||
f.response = self.read_response_headers()
|
f.response = self.read_response_headers()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -406,10 +423,9 @@ class HttpLayer(base.Layer):
|
|||||||
# we now need to emulate the responseheaders hook.
|
# we now need to emulate the responseheaders hook.
|
||||||
self.channel.ask("responseheaders", f)
|
self.channel.ask("responseheaders", f)
|
||||||
|
|
||||||
|
f.response.data.trailers = self.read_response_trailers(f.request, f.response)
|
||||||
|
|
||||||
self.log("response", "debug", [repr(f.response)])
|
self.log("response", "debug", [repr(f.response)])
|
||||||
# not support HTTP/1.1 trailers
|
|
||||||
if f.request.http_version == "HTTP/2.0":
|
|
||||||
f.response.data.trailers = self.read_trailers_headers()
|
|
||||||
self.channel.ask("response", f)
|
self.channel.ask("response", f)
|
||||||
|
|
||||||
if not f.response.stream:
|
if not f.response.stream:
|
||||||
|
@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
human.parse_size(self.config.options.body_size_limit)
|
human.parse_size(self.config.options.body_size_limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def read_request_trailers(self, request):
|
||||||
|
if "Trailer" in request.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||||
|
return None
|
||||||
|
|
||||||
def send_request_headers(self, request):
|
def send_request_headers(self, request):
|
||||||
headers = http1.assemble_request_head(request)
|
headers = http1.assemble_request_head(request)
|
||||||
self.server_conn.wfile.write(headers)
|
self.server_conn.wfile.write(headers)
|
||||||
@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
self.server_conn.wfile.write(chunk)
|
self.server_conn.wfile.write(chunk)
|
||||||
self.server_conn.wfile.flush()
|
self.server_conn.wfile.flush()
|
||||||
|
|
||||||
|
def send_request_trailers(self, request):
|
||||||
|
if "Trailer" in request.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||||
|
|
||||||
def send_request(self, request):
|
def send_request(self, request):
|
||||||
|
# TODO: this does not yet support request trailers
|
||||||
self.server_conn.wfile.write(http1.assemble_request(request))
|
self.server_conn.wfile.write(http1.assemble_request(request))
|
||||||
self.server_conn.wfile.flush()
|
self.server_conn.wfile.flush()
|
||||||
|
|
||||||
@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
human.parse_size(self.config.options.body_size_limit)
|
human.parse_size(self.config.options.body_size_limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def read_response_trailers(self, request, response):
|
||||||
|
if "Trailer" in response.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||||
|
return None
|
||||||
|
|
||||||
def send_response_headers(self, response):
|
def send_response_headers(self, response):
|
||||||
raw = http1.assemble_response_head(response)
|
raw = http1.assemble_response_head(response)
|
||||||
self.client_conn.wfile.write(raw)
|
self.client_conn.wfile.write(raw)
|
||||||
@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
self.client_conn.wfile.write(chunk)
|
self.client_conn.wfile.write(chunk)
|
||||||
self.client_conn.wfile.flush()
|
self.client_conn.wfile.flush()
|
||||||
|
|
||||||
|
def send_response_trailers(self, response):
|
||||||
|
if "Trailer" in response.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||||
|
return
|
||||||
|
|
||||||
def check_close_connection(self, flow):
|
def check_close_connection(self, flow):
|
||||||
request_close = http1.connection_close(
|
request_close = http1.connection_close(
|
||||||
flow.request.http_version,
|
flow.request.http_version,
|
||||||
|
@ -235,8 +235,10 @@ class Http2Layer(base.Layer):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_trailers(self, eid, event, is_server, other_conn):
|
def _handle_trailers(self, eid, event, is_server, other_conn):
|
||||||
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||||
self.streams[eid].update_trailers(headers)
|
# TODO: support request trailers as well!
|
||||||
|
self.streams[eid].response_trailers = trailers
|
||||||
|
self.streams[eid].response_trailers_arrived.set()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_remote_settings_changed(self, event, other_conn):
|
def _handle_remote_settings_changed(self, event, other_conn):
|
||||||
@ -417,15 +419,17 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
|
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
|
||||||
self.request_queued_data_length = 0
|
self.request_queued_data_length = 0
|
||||||
self.request_data_finished = threading.Event()
|
self.request_data_finished = threading.Event()
|
||||||
|
self.request_trailers_arrived = threading.Event()
|
||||||
|
self.request_trailers = None
|
||||||
|
|
||||||
self.response_arrived = threading.Event()
|
self.response_arrived = threading.Event()
|
||||||
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
|
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
|
||||||
self.response_queued_data_length = 0
|
self.response_queued_data_length = 0
|
||||||
self.response_data_finished = threading.Event()
|
self.response_data_finished = threading.Event()
|
||||||
|
self.response_trailers_arrived = threading.Event()
|
||||||
|
self.response_trailers = None
|
||||||
|
|
||||||
self.no_body = False
|
self.no_body = False
|
||||||
self.has_trailers = False
|
|
||||||
self.trailers_header = None
|
|
||||||
|
|
||||||
self.priority_exclusive: bool
|
self.priority_exclusive: bool
|
||||||
self.priority_depends_on: Optional[int] = None
|
self.priority_depends_on: Optional[int] = None
|
||||||
@ -437,8 +441,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
self.zombie = time.time()
|
self.zombie = time.time()
|
||||||
self.request_data_finished.set()
|
self.request_data_finished.set()
|
||||||
self.request_arrived.set()
|
self.request_arrived.set()
|
||||||
|
self.request_trailers_arrived.set()
|
||||||
self.response_arrived.set()
|
self.response_arrived.set()
|
||||||
self.response_data_finished.set()
|
self.response_data_finished.set()
|
||||||
|
self.response_trailers_arrived.set()
|
||||||
|
|
||||||
def connect(self): # pragma: no cover
|
def connect(self): # pragma: no cover
|
||||||
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
|
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
|
||||||
@ -526,6 +532,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
break
|
break
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
|
|
||||||
|
@detect_zombie_stream
|
||||||
|
def read_request_trailers(self, request):
|
||||||
|
if "trailer" in request.headers:
|
||||||
|
self.request_trailers_arrived.wait()
|
||||||
|
self.raise_zombie()
|
||||||
|
return self.request_trailers
|
||||||
|
return None
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_request_headers(self, request):
|
def send_request_headers(self, request):
|
||||||
if self.pushed:
|
if self.pushed:
|
||||||
@ -600,25 +614,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
)
|
)
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def update_trailers(self, headers):
|
def send_request_trailers(self, request):
|
||||||
self.trailers_header = headers
|
self._send_trailers(self.server_conn, self.request_trailers)
|
||||||
self.has_trailers = True
|
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_trailers_headers(self):
|
def send_request(self, request):
|
||||||
if self.has_trailers and self.trailers_header:
|
self.send_request_headers(request)
|
||||||
with self.connections[self.client_conn].lock:
|
self.send_request_body(request, [request.content])
|
||||||
self.connections[self.client_conn].safe_send_headers(
|
self.send_request_trailers(request)
|
||||||
self.raise_zombie,
|
|
||||||
self.client_stream_id,
|
|
||||||
self.trailers_header,
|
|
||||||
end_stream = True
|
|
||||||
)
|
|
||||||
|
|
||||||
@detect_zombie_stream
|
|
||||||
def send_request(self, message):
|
|
||||||
self.send_request_headers(message)
|
|
||||||
self.send_request_body(message, [message.content])
|
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def read_response_headers(self):
|
def read_response_headers(self):
|
||||||
@ -640,10 +643,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
timestamp_end=self.timestamp_end,
|
timestamp_end=self.timestamp_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
@detect_zombie_stream
|
|
||||||
def read_trailers_headers(self):
|
|
||||||
return self.trailers_header
|
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def read_response_body(self, request, response):
|
def read_response_body(self, request, response):
|
||||||
while True:
|
while True:
|
||||||
@ -658,6 +657,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
break
|
break
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
|
|
||||||
|
@detect_zombie_stream
|
||||||
|
def read_response_trailers(self, request, response):
|
||||||
|
if "trailer" in response.headers:
|
||||||
|
self.response_trailers_arrived.wait()
|
||||||
|
self.raise_zombie()
|
||||||
|
return self.response_trailers
|
||||||
|
return None
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_response_headers(self, response):
|
def send_response_headers(self, response):
|
||||||
headers = response.headers.copy()
|
headers = response.headers.copy()
|
||||||
@ -670,15 +677,28 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
)
|
)
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_response_body(self, _response, chunks):
|
def send_response_body(self, response, chunks):
|
||||||
self.connections[self.client_conn].safe_send_body(
|
self.connections[self.client_conn].safe_send_body(
|
||||||
self.raise_zombie,
|
self.raise_zombie,
|
||||||
self.client_stream_id,
|
self.client_stream_id,
|
||||||
chunks,
|
chunks,
|
||||||
end_stream = not self.has_trailers
|
end_stream=("trailer" not in response.headers)
|
||||||
|
)
|
||||||
|
|
||||||
|
@detect_zombie_stream
|
||||||
|
def send_response_trailers(self, _response):
|
||||||
|
self._send_trailers(self.client_conn, self.response_trailers)
|
||||||
|
|
||||||
|
def _send_trailers(self, conn, trailers):
|
||||||
|
if not trailers:
|
||||||
|
return
|
||||||
|
with self.connections[conn].lock:
|
||||||
|
self.connections[conn].safe_send_headers(
|
||||||
|
self.raise_zombie,
|
||||||
|
self.client_stream_id,
|
||||||
|
trailers,
|
||||||
|
end_stream=True
|
||||||
)
|
)
|
||||||
if self.has_trailers:
|
|
||||||
self.send_trailers_headers()
|
|
||||||
|
|
||||||
def __call__(self): # pragma: no cover
|
def __call__(self): # pragma: no cover
|
||||||
raise EnvironmentError('Http2SingleStreamLayer must be run as thread')
|
raise EnvironmentError('Http2SingleStreamLayer must be run as thread')
|
||||||
|
@ -110,8 +110,9 @@ class HTTP2StateProtocol:
|
|||||||
b"HTTP/2.0",
|
b"HTTP/2.0",
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
timestamp_start,
|
None,
|
||||||
timestamp_end,
|
timestamp_start=timestamp_start,
|
||||||
|
timestamp_end=timestamp_end,
|
||||||
)
|
)
|
||||||
request.stream_id = stream_id
|
request.stream_id = stream_id
|
||||||
|
|
||||||
|
@ -21,8 +21,11 @@ class TestRequestData:
|
|||||||
treq(headers="foobar")
|
treq(headers="foobar")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
treq(content="foobar")
|
treq(content="foobar")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
treq(trailers="foobar")
|
||||||
|
|
||||||
assert isinstance(treq(headers=()).headers, Headers)
|
assert isinstance(treq(headers=()).headers, Headers)
|
||||||
|
assert isinstance(treq(trailers=()).trailers, Headers)
|
||||||
|
|
||||||
|
|
||||||
class TestRequestCore:
|
class TestRequestCore:
|
||||||
|
@ -20,8 +20,11 @@ class TestResponseData:
|
|||||||
tresp(reason="fööbär")
|
tresp(reason="fööbär")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
tresp(content="foobar")
|
tresp(content="foobar")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
tresp(trailers="foobar")
|
||||||
|
|
||||||
assert isinstance(tresp(headers=()).headers, Headers)
|
assert isinstance(tresp(headers=()).headers, Headers)
|
||||||
|
assert isinstance(tresp(trailers=()).trailers, Headers)
|
||||||
|
|
||||||
|
|
||||||
class TestResponseCore:
|
class TestResponseCore:
|
||||||
|
@ -1034,37 +1034,19 @@ class TestResponseStreaming(_Http2Test):
|
|||||||
|
|
||||||
|
|
||||||
class TestTrailers(_Http2Test):
|
class TestTrailers(_Http2Test):
|
||||||
request_body_buffer = b''
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def handle_server_event(cls, event, h2_conn, rfile, wfile):
|
def handle_server_event(cls, event, h2_conn, rfile, wfile):
|
||||||
if isinstance(event, h2.events.ConnectionTerminated):
|
if isinstance(event, h2.events.ConnectionTerminated):
|
||||||
return False
|
return False
|
||||||
elif isinstance(event, h2.events.RequestReceived):
|
|
||||||
assert (b'self.client-foo', b'self.client-bar-1') in event.headers
|
|
||||||
assert (b'self.client-foo', b'self.client-bar-2') in event.headers
|
|
||||||
elif isinstance(event, h2.events.StreamEnded):
|
elif isinstance(event, h2.events.StreamEnded):
|
||||||
import warnings
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
# Ignore UnicodeWarning:
|
|
||||||
# h2/utilities.py:64: UnicodeWarning: Unicode equal comparison
|
|
||||||
# failed to convert both arguments to Unicode - interpreting
|
|
||||||
# them as being unequal.
|
|
||||||
# elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20:
|
|
||||||
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
h2_conn.send_headers(event.stream_id, [
|
h2_conn.send_headers(event.stream_id, [
|
||||||
(':status', '200'),
|
(':status', '200'),
|
||||||
('server-foo', 'server-bar'),
|
('trailer', 'x-my-trailers')
|
||||||
('föo', 'bär'),
|
|
||||||
('X-Stream-ID', str(event.stream_id)),
|
|
||||||
])
|
])
|
||||||
h2_conn.send_data(event.stream_id, b'response body')
|
h2_conn.send_data(event.stream_id, b'response body')
|
||||||
h2_conn.send_headers(event.stream_id, [('trailers', 'trailers-foo')], end_stream=True)
|
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
|
||||||
wfile.write(h2_conn.data_to_send())
|
wfile.write(h2_conn.data_to_send())
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
elif isinstance(event, h2.events.DataReceived):
|
|
||||||
cls.request_body_buffer += event.data
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def test_trailers(self):
|
def test_trailers(self):
|
||||||
@ -1079,11 +1061,9 @@ class TestTrailers(_Http2Test):
|
|||||||
(':method', 'GET'),
|
(':method', 'GET'),
|
||||||
(':scheme', 'https'),
|
(':scheme', 'https'),
|
||||||
(':path', '/'),
|
(':path', '/'),
|
||||||
('self.client-FoO', 'self.client-bar-1'),
|
])
|
||||||
('self.client-FoO', 'self.client-bar-2'),
|
|
||||||
],
|
|
||||||
body=b'request body')
|
|
||||||
|
|
||||||
|
trailers_buffer = None
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
try:
|
try:
|
||||||
@ -1099,6 +1079,8 @@ class TestTrailers(_Http2Test):
|
|||||||
for event in events:
|
for event in events:
|
||||||
if isinstance(event, h2.events.DataReceived):
|
if isinstance(event, h2.events.DataReceived):
|
||||||
response_body_buffer += event.data
|
response_body_buffer += event.data
|
||||||
|
elif isinstance(event, h2.events.TrailersReceived):
|
||||||
|
trailers_buffer = event.headers
|
||||||
elif isinstance(event, h2.events.StreamEnded):
|
elif isinstance(event, h2.events.StreamEnded):
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
@ -1108,9 +1090,7 @@ class TestTrailers(_Http2Test):
|
|||||||
|
|
||||||
assert len(self.master.state.flows) == 1
|
assert len(self.master.state.flows) == 1
|
||||||
assert self.master.state.flows[0].response.status_code == 200
|
assert self.master.state.flows[0].response.status_code == 200
|
||||||
assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar'
|
|
||||||
assert self.master.state.flows[0].response.headers['föo'] == 'bär'
|
|
||||||
assert self.master.state.flows[0].response.content == b'response body'
|
assert self.master.state.flows[0].response.content == b'response body'
|
||||||
assert self.request_body_buffer == b'request body'
|
|
||||||
assert response_body_buffer == b'response body'
|
assert response_body_buffer == b'response body'
|
||||||
assert self.master.state.flows[0].response.data.trailers['trailers'] == 'trailers-foo'
|
assert self.master.state.flows[0].response.data.trailers['x-my-trailers'] == 'foobar'
|
||||||
|
assert trailers_buffer == [(b'x-my-trailers', b'foobar')]
|
||||||
|
Loading…
Reference in New Issue
Block a user