Merge pull request #1178 from cortesi/pseudohdrs

Improve handling of HTTP2 pseudo-headers
This commit is contained in:
Aldo Cortesi 2016-05-31 16:34:28 +12:00
commit 2f526393d2
4 changed files with 35 additions and 10 deletions

View File

@ -306,6 +306,9 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
method = self.request_headers.get(':method', 'GET') method = self.request_headers.get(':method', 'GET')
scheme = self.request_headers.get(':scheme', 'https') scheme = self.request_headers.get(':scheme', 'https')
path = self.request_headers.get(':path', '/') path = self.request_headers.get(':path', '/')
self.request_headers.clear(":method")
self.request_headers.clear(":scheme")
self.request_headers.clear(":path")
host = None host = None
port = None port = None
@ -362,10 +365,15 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
headers = message.headers.copy()
headers.insert(0, ":path", message.path)
headers.insert(0, ":method", message.method)
headers.insert(0, ":scheme", message.scheme)
self.server_conn.h2.safe_send_headers( self.server_conn.h2.safe_send_headers(
self.is_zombie, self.is_zombie,
self.server_stream_id, self.server_stream_id,
message.headers headers
) )
self.server_conn.h2.safe_send_body( self.server_conn.h2.safe_send_body(
self.is_zombie, self.is_zombie,
@ -379,12 +387,14 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
self.response_arrived.wait() self.response_arrived.wait()
status_code = int(self.response_headers.get(':status', 502)) status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
headers.clear(":status")
return HTTPResponse( return HTTPResponse(
http_version=b"HTTP/2.0", http_version=b"HTTP/2.0",
status_code=status_code, status_code=status_code,
reason='', reason='',
headers=self.response_headers, headers=headers,
content=None, content=None,
timestamp_start=self.timestamp_start, timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end, timestamp_end=self.timestamp_end,
@ -404,10 +414,12 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
raise Http2ProtocolException("Zombie Stream") raise Http2ProtocolException("Zombie Stream")
def send_response_headers(self, response): def send_response_headers(self, response):
headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code))
self.client_conn.h2.safe_send_headers( self.client_conn.h2.safe_send_headers(
self.is_zombie, self.is_zombie,
self.client_stream_id, self.client_stream_id,
response.headers headers
) )
if self.zombie: # pragma: no cover if self.zombie: # pragma: no cover
raise Http2ProtocolException("Zombie Stream") raise Http2ProtocolException("Zombie Stream")

View File

@ -98,6 +98,11 @@ class HTTP2Protocol(object):
method = headers.get(':method', 'GET') method = headers.get(':method', 'GET')
scheme = headers.get(':scheme', 'https') scheme = headers.get(':scheme', 'https')
path = headers.get(':path', '/') path = headers.get(':path', '/')
headers.clear(":method")
headers.clear(":scheme")
headers.clear(":path")
host = None host = None
port = None port = None
@ -202,11 +207,8 @@ class HTTP2Protocol(object):
if ':authority' not in headers: if ':authority' not in headers:
headers.insert(0, b':authority', authority.encode('ascii')) headers.insert(0, b':authority', authority.encode('ascii'))
if ':scheme' not in headers:
headers.insert(0, b':scheme', request.scheme.encode('ascii')) headers.insert(0, b':scheme', request.scheme.encode('ascii'))
if ':path' not in headers:
headers.insert(0, b':path', request.path.encode('ascii')) headers.insert(0, b':path', request.path.encode('ascii'))
if ':method' not in headers:
headers.insert(0, b':method', request.method.encode('ascii')) headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'): if hasattr(request, 'stream_id'):

View File

@ -171,6 +171,14 @@ class _MultiDict(MutableMapping, Serializable):
else: else:
return super(_MultiDict, self).items() return super(_MultiDict, self).items()
def clear(self, key):
"""
Removes all items with the specified key, and does not raise an
exception if the key does not exist.
"""
if key in self:
del self[key]
def to_dict(self): def to_dict(self):
""" """
Get the MultiDict as a plain Python dict. Get the MultiDict as a plain Python dict.

View File

@ -312,7 +312,10 @@ class TestReadRequest(tservers.ServerTestBase):
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
assert req.stream_id assert req.stream_id
assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https')) assert req.headers.fields == ()
assert req.method == "GET"
assert req.path == "/"
assert req.scheme == "https"
assert req.content == b'foobar' assert req.content == b'foobar'