Merge remote-tracking branch 'duffer/pretty-host'

This commit is contained in:
Maximilian Hils 2016-02-18 23:17:02 +01:00
commit ecb26c3c82
3 changed files with 39 additions and 5 deletions

View File

@ -1,5 +1,6 @@
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import re
import warnings import warnings
import six import six
@ -12,6 +13,10 @@ from .. import encoding
from .headers import Headers from .headers import Headers
from .message import Message, _native, _always_bytes, MessageData from .message import Message, _native, _always_bytes, MessageData
# Compiled pattern that splits a Host header value into its host and port parts.
# The "host" group matches either a run of characters containing no colon, or a
# bracketed literal such as "[::1]", so the colons inside an IPv6 address are
# not mistaken for the host/port separator.
# The optional "port" group matches the decimal digits that follow the colon
# after the host.
# See https://bugzilla.mozilla.org/show_bug.cgi?id=45891
host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
class RequestData(MessageData): class RequestData(MessageData):
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
@ -159,6 +164,18 @@ class Request(Message):
def url(self, url): def url(self, url):
self.scheme, self.host, self.port, self.path = utils.parse_url(url) self.scheme, self.host, self.port, self.path = utils.parse_url(url)
def _parse_host_header(self):
    """
    Split the request's Host header into a ``(host, port)`` pair.

    Returns:
        ``(None, None)`` if the request carries no Host header.
        ``(raw_value, None)`` if the header value does not match
        ``host_header_re``; the value is passed through untouched.
        Otherwise ``(host, port)``, where any leading or trailing ``[`` / ``]``
        characters are removed from the host and ``port`` is an ``int``, or
        ``None`` when the header names no port.
    """
    if "host" not in self.headers:
        return None, None

    raw_value = self.headers["host"]
    match = host_header_re.match(raw_value)
    if not match:
        # Unrecognised form: hand the header back as-is, without a port.
        return raw_value, None

    port_digits = match.group("port")
    return (
        # Strip the brackets that wrap an IPv6 literal.
        match.group("host").strip("[]"),
        int(port_digits) if port_digits else None,
    )
@property @property
def pretty_host(self): def pretty_host(self):
""" """
@ -166,7 +183,13 @@ class Request(Message):
This is useful in transparent mode where :py:attr:`host` is only an IP address, This is useful in transparent mode where :py:attr:`host` is only an IP address,
but may not reflect the actual destination as the Host header could be spoofed. but may not reflect the actual destination as the Host header could be spoofed.
""" """
return self.headers.get("host", self.host) host, port = self._parse_host_header()
if not host:
return self.host
if not port:
port = 443 if self.scheme == 'https' else 80
# Prefer the original address if host header has an unexpected form
return host if port == self.port else self.host
@property @property
def pretty_url(self): def pretty_url(self):

View File

@ -1035,7 +1035,7 @@ class TestRequest:
assert r.url == "https://address:22/path" assert r.url == "https://address:22/path"
assert r.pretty_url == "https://address:22/path" assert r.pretty_url == "https://address:22/path"
r.headers["Host"] = "foo.com" r.headers["Host"] = "foo.com:22"
assert r.url == "https://address:22/path" assert r.url == "https://address:22/path"
assert r.pretty_url == "https://foo.com:22/path" assert r.pretty_url == "https://foo.com:22/path"

View File

@ -104,25 +104,36 @@ class TestRequestUtils(object):
def test_pretty_host(self): def test_pretty_host(self):
request = treq() request = treq()
# Without host header
assert request.pretty_host == "address" assert request.pretty_host == "address"
assert request.host == "address" assert request.host == "address"
request.headers["host"] = "other" # Same port as self.port (22)
request.headers["host"] = "other:22"
assert request.pretty_host == "other" assert request.pretty_host == "other"
# Different ports
request.headers["host"] = "other"
assert request.pretty_host == "address"
assert request.host == "address" assert request.host == "address"
# Empty host
request.host = None request.host = None
assert request.pretty_host is None assert request.pretty_host is None
assert request.host is None assert request.host is None
# Invalid IDNA # Invalid IDNA
request.headers["host"] = ".disqus.com" request.headers["host"] = ".disqus.com:22"
assert request.pretty_host == ".disqus.com" assert request.pretty_host == ".disqus.com"
def test_pretty_url(self): def test_pretty_url(self):
request = treq() request = treq()
# Without host header
assert request.url == "http://address:22/path" assert request.url == "http://address:22/path"
assert request.pretty_url == "http://address:22/path" assert request.pretty_url == "http://address:22/path"
request.headers["host"] = "other" # Same port as self.port (22)
request.headers["host"] = "other:22"
assert request.pretty_url == "http://other:22/path" assert request.pretty_url == "http://other:22/path"
# Different ports
request.headers["host"] = "other"
assert request.pretty_url == "http://address:22/path"
def test_pretty_url_authority(self): def test_pretty_url_authority(self):
request = treq(first_line_format="authority") request = treq(first_line_format="authority")