Merge pull request #2040 from mhils/request-host-header

Add "Request.host_header"
This commit is contained in:
Maximilian Hils 2017-02-18 12:08:54 +01:00 committed by GitHub
commit 4158a1ae55
6 changed files with 96 additions and 16 deletions

View File

@ -34,7 +34,7 @@ class Rerouter:
The original host header is retrieved early The original host header is retrieved early
before flow.request is replaced by mitmproxy new outgoing request before flow.request is replaced by mitmproxy new outgoing request
""" """
flow.metadata["original_host"] = flow.request.headers["Host"] flow.metadata["original_host"] = flow.request.host_header
def request(self, flow): def request(self, flow):
if flow.client_conn.ssl_established: if flow.client_conn.ssl_established:
@ -53,7 +53,7 @@ class Rerouter:
if m.group("port"): if m.group("port"):
port = int(m.group("port")) port = int(m.group("port"))
flow.request.headers["Host"] = host_header flow.request.host_header = host_header
flow.request.host = sni or host_header flow.request.host = sni or host_header
flow.request.port = port flow.request.port = port

View File

@ -73,7 +73,7 @@ def python_code(flow: http.HTTPFlow):
headers = flow.request.headers.copy() headers = flow.request.headers.copy()
# requests adds those by default. # requests adds those by default.
for x in ("host", "content-length"): for x in (":authority", "host", "content-length"):
headers.pop(x, None) headers.pop(x, None)
writearg("headers", dict(headers)) writearg("headers", dict(headers))
try: try:
@ -130,7 +130,7 @@ def locust_code(flow):
if flow.request.headers: if flow.request.headers:
lines = [ lines = [
(_native(k), _native(v)) for k, v in flow.request.headers.fields (_native(k), _native(v)) for k, v in flow.request.headers.fields
if _native(k).lower() not in ["host", "cookie"] if _native(k).lower() not in [":authority", "host", "cookie"]
] ]
lines = [" '%s': '%s',\n" % (k, v) for k, v in lines] lines = [" '%s': '%s',\n" % (k, v) for k, v in lines]
headers += "\n headers = {\n%s }\n" % "".join(lines) headers += "\n headers = {\n%s }\n" % "".join(lines)

View File

@ -78,8 +78,9 @@ def _assemble_request_headers(request_data):
Args: Args:
request_data (mitmproxy.net.http.request.RequestData) request_data (mitmproxy.net.http.request.RequestData)
""" """
headers = request_data.headers.copy() headers = request_data.headers
if "host" not in headers and request_data.scheme and request_data.host and request_data.port: if "host" not in headers and request_data.scheme and request_data.host and request_data.port:
headers = headers.copy()
headers["host"] = mitmproxy.net.http.url.hostport( headers["host"] = mitmproxy.net.http.url.hostport(
request_data.scheme, request_data.scheme,
request_data.host, request_data.host,

View File

@ -1,5 +1,6 @@
import re import re
import urllib import urllib
from typing import Optional
from mitmproxy.types import multidict from mitmproxy.types import multidict
from mitmproxy.utils import strutils from mitmproxy.utils import strutils
@ -164,11 +165,44 @@ class Request(message.Message):
self.data.host = host self.data.host = host
# Update host header # Update host header
if "host" in self.headers: if self.host_header is not None:
if host: self.host_header = host
self.headers["host"] = host
@property
def host_header(self) -> Optional[str]:
"""
The request's host/authority header.
This property maps to either ``request.headers["Host"]`` or
``request.headers[":authority"]``, depending on whether it's HTTP/1.x or HTTP/2.0.
"""
if ":authority" in self.headers:
return self.headers[":authority"]
if "Host" in self.headers:
return self.headers["Host"]
return None
@host_header.setter
def host_header(self, val: Optional[str]) -> None:
if val is None:
self.headers.pop("Host", None)
self.headers.pop(":authority", None)
elif self.host_header is not None:
# Update any existing headers.
if ":authority" in self.headers:
self.headers[":authority"] = val
if "Host" in self.headers:
self.headers["Host"] = val
else: else:
self.headers.pop("host") # Only add the correct new header.
if self.http_version.upper().startswith("HTTP/2"):
self.headers[":authority"] = val
else:
self.headers["Host"] = val
@host_header.deleter
def host_header(self):
self.host_header = None
@property @property
def port(self): def port(self):
@ -211,9 +245,10 @@ class Request(message.Message):
def _parse_host_header(self): def _parse_host_header(self):
"""Extract the host and port from Host header""" """Extract the host and port from Host header"""
if "host" not in self.headers: host = self.host_header
if not host:
return None, None return None, None
host, port = self.headers["host"], None port = None
m = host_header_re.match(host) m = host_header_re.match(host)
if m: if m:
host = m.group("host").strip("[]") host = m.group("host").strip("[]")

View File

@ -291,7 +291,7 @@ class HttpLayer(base.Layer):
# update host header in reverse proxy mode # update host header in reverse proxy mode
if self.config.options.mode == "reverse": if self.config.options.mode == "reverse":
f.request.headers["Host"] = self.config.upstream_server.address.host f.request.host_header = self.config.upstream_server.address.host
# Determine .scheme, .host and .port attributes for inline scripts. For # Determine .scheme, .host and .port attributes for inline scripts. For
# absolute-form requests, they are directly given in the request. For # absolute-form requests, they are directly given in the request. For
@ -301,11 +301,10 @@ class HttpLayer(base.Layer):
if self.mode is HTTPMode.transparent: if self.mode is HTTPMode.transparent:
# Setting request.host also updates the host header, which we want # Setting request.host also updates the host header, which we want
# to preserve # to preserve
host_header = f.request.headers.get("host", None) host_header = f.request.host_header
f.request.host = self.__initial_server_conn.address.host f.request.host = self.__initial_server_conn.address.host
f.request.port = self.__initial_server_conn.address.port f.request.port = self.__initial_server_conn.address.port
if host_header: f.request.host_header = host_header # set again as .host overwrites this.
f.request.headers["host"] = host_header
f.request.scheme = "https" if self.__initial_server_tls else "http" f.request.scheme = "https" if self.__initial_server_tls else "http"
self.channel.ask("request", f) self.channel.ask("request", f)

View File

@ -97,7 +97,7 @@ class TestRequestCore:
request.host = d request.host = d
assert request.data.host == b"foo\xFF\x00bar" assert request.data.host == b"foo\xFF\x00bar"
def test_host_header_update(self): def test_host_update_also_updates_header(self):
request = treq() request = treq()
assert "host" not in request.headers assert "host" not in request.headers
request.host = "example.com" request.host = "example.com"
@ -107,6 +107,51 @@ class TestRequestCore:
request.host = "example.org" request.host = "example.org"
assert request.headers["Host"] == "example.org" assert request.headers["Host"] == "example.org"
def test_get_host_header(self):
no_hdr = treq()
assert no_hdr.host_header is None
h1 = treq(headers=(
(b"host", b"example.com"),
))
assert h1.host_header == "example.com"
h2 = treq(headers=(
(b":authority", b"example.org"),
))
assert h2.host_header == "example.org"
both_hdrs = treq(headers=(
(b"host", b"example.org"),
(b":authority", b"example.com"),
))
assert both_hdrs.host_header == "example.com"
def test_modify_host_header(self):
h1 = treq()
assert "host" not in h1.headers
assert ":authority" not in h1.headers
h1.host_header = "example.com"
assert "host" in h1.headers
assert ":authority" not in h1.headers
h1.host_header = None
assert "host" not in h1.headers
h2 = treq(http_version=b"HTTP/2.0")
h2.host_header = "example.org"
assert "host" not in h2.headers
assert ":authority" in h2.headers
del h2.host_header
assert ":authority" not in h2.headers
both_hdrs = treq(headers=(
(b":authority", b"example.com"),
(b"host", b"example.org"),
))
both_hdrs.host_header = "foo.example.com"
assert both_hdrs.headers["Host"] == "foo.example.com"
assert both_hdrs.headers[":authority"] == "foo.example.com"
class TestRequestUtils: class TestRequestUtils:
""" """