mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-02-07 10:40:09 +00:00
Merge pull request #2040 from mhils/request-host-header
Add "Request.host_header"
This commit is contained in:
commit
4158a1ae55
@ -34,7 +34,7 @@ class Rerouter:
|
|||||||
The original host header is retrieved early
|
The original host header is retrieved early
|
||||||
before flow.request is replaced by mitmproxy new outgoing request
|
before flow.request is replaced by mitmproxy new outgoing request
|
||||||
"""
|
"""
|
||||||
flow.metadata["original_host"] = flow.request.headers["Host"]
|
flow.metadata["original_host"] = flow.request.host_header
|
||||||
|
|
||||||
def request(self, flow):
|
def request(self, flow):
|
||||||
if flow.client_conn.ssl_established:
|
if flow.client_conn.ssl_established:
|
||||||
@ -53,7 +53,7 @@ class Rerouter:
|
|||||||
if m.group("port"):
|
if m.group("port"):
|
||||||
port = int(m.group("port"))
|
port = int(m.group("port"))
|
||||||
|
|
||||||
flow.request.headers["Host"] = host_header
|
flow.request.host_header = host_header
|
||||||
flow.request.host = sni or host_header
|
flow.request.host = sni or host_header
|
||||||
flow.request.port = port
|
flow.request.port = port
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ def python_code(flow: http.HTTPFlow):
|
|||||||
|
|
||||||
headers = flow.request.headers.copy()
|
headers = flow.request.headers.copy()
|
||||||
# requests adds those by default.
|
# requests adds those by default.
|
||||||
for x in ("host", "content-length"):
|
for x in (":authority", "host", "content-length"):
|
||||||
headers.pop(x, None)
|
headers.pop(x, None)
|
||||||
writearg("headers", dict(headers))
|
writearg("headers", dict(headers))
|
||||||
try:
|
try:
|
||||||
@ -130,7 +130,7 @@ def locust_code(flow):
|
|||||||
if flow.request.headers:
|
if flow.request.headers:
|
||||||
lines = [
|
lines = [
|
||||||
(_native(k), _native(v)) for k, v in flow.request.headers.fields
|
(_native(k), _native(v)) for k, v in flow.request.headers.fields
|
||||||
if _native(k).lower() not in ["host", "cookie"]
|
if _native(k).lower() not in [":authority", "host", "cookie"]
|
||||||
]
|
]
|
||||||
lines = [" '%s': '%s',\n" % (k, v) for k, v in lines]
|
lines = [" '%s': '%s',\n" % (k, v) for k, v in lines]
|
||||||
headers += "\n headers = {\n%s }\n" % "".join(lines)
|
headers += "\n headers = {\n%s }\n" % "".join(lines)
|
||||||
|
@ -78,8 +78,9 @@ def _assemble_request_headers(request_data):
|
|||||||
Args:
|
Args:
|
||||||
request_data (mitmproxy.net.http.request.RequestData)
|
request_data (mitmproxy.net.http.request.RequestData)
|
||||||
"""
|
"""
|
||||||
headers = request_data.headers.copy()
|
headers = request_data.headers
|
||||||
if "host" not in headers and request_data.scheme and request_data.host and request_data.port:
|
if "host" not in headers and request_data.scheme and request_data.host and request_data.port:
|
||||||
|
headers = headers.copy()
|
||||||
headers["host"] = mitmproxy.net.http.url.hostport(
|
headers["host"] = mitmproxy.net.http.url.hostport(
|
||||||
request_data.scheme,
|
request_data.scheme,
|
||||||
request_data.host,
|
request_data.host,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
import urllib
|
import urllib
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from mitmproxy.types import multidict
|
from mitmproxy.types import multidict
|
||||||
from mitmproxy.utils import strutils
|
from mitmproxy.utils import strutils
|
||||||
@ -164,11 +165,44 @@ class Request(message.Message):
|
|||||||
self.data.host = host
|
self.data.host = host
|
||||||
|
|
||||||
# Update host header
|
# Update host header
|
||||||
if "host" in self.headers:
|
if self.host_header is not None:
|
||||||
if host:
|
self.host_header = host
|
||||||
self.headers["host"] = host
|
|
||||||
|
@property
|
||||||
|
def host_header(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
The request's host/authority header.
|
||||||
|
|
||||||
|
This property maps to either ``request.headers["Host"]`` or
|
||||||
|
``request.headers[":authority"]``, depending on whether it's HTTP/1.x or HTTP/2.0.
|
||||||
|
"""
|
||||||
|
if ":authority" in self.headers:
|
||||||
|
return self.headers[":authority"]
|
||||||
|
if "Host" in self.headers:
|
||||||
|
return self.headers["Host"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@host_header.setter
|
||||||
|
def host_header(self, val: Optional[str]) -> None:
|
||||||
|
if val is None:
|
||||||
|
self.headers.pop("Host", None)
|
||||||
|
self.headers.pop(":authority", None)
|
||||||
|
elif self.host_header is not None:
|
||||||
|
# Update any existing headers.
|
||||||
|
if ":authority" in self.headers:
|
||||||
|
self.headers[":authority"] = val
|
||||||
|
if "Host" in self.headers:
|
||||||
|
self.headers["Host"] = val
|
||||||
|
else:
|
||||||
|
# Only add the correct new header.
|
||||||
|
if self.http_version.upper().startswith("HTTP/2"):
|
||||||
|
self.headers[":authority"] = val
|
||||||
else:
|
else:
|
||||||
self.headers.pop("host")
|
self.headers["Host"] = val
|
||||||
|
|
||||||
|
@host_header.deleter
|
||||||
|
def host_header(self):
|
||||||
|
self.host_header = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def port(self):
|
def port(self):
|
||||||
@ -211,9 +245,10 @@ class Request(message.Message):
|
|||||||
|
|
||||||
def _parse_host_header(self):
|
def _parse_host_header(self):
|
||||||
"""Extract the host and port from Host header"""
|
"""Extract the host and port from Host header"""
|
||||||
if "host" not in self.headers:
|
host = self.host_header
|
||||||
|
if not host:
|
||||||
return None, None
|
return None, None
|
||||||
host, port = self.headers["host"], None
|
port = None
|
||||||
m = host_header_re.match(host)
|
m = host_header_re.match(host)
|
||||||
if m:
|
if m:
|
||||||
host = m.group("host").strip("[]")
|
host = m.group("host").strip("[]")
|
||||||
|
@ -291,7 +291,7 @@ class HttpLayer(base.Layer):
|
|||||||
|
|
||||||
# update host header in reverse proxy mode
|
# update host header in reverse proxy mode
|
||||||
if self.config.options.mode == "reverse":
|
if self.config.options.mode == "reverse":
|
||||||
f.request.headers["Host"] = self.config.upstream_server.address.host
|
f.request.host_header = self.config.upstream_server.address.host
|
||||||
|
|
||||||
# Determine .scheme, .host and .port attributes for inline scripts. For
|
# Determine .scheme, .host and .port attributes for inline scripts. For
|
||||||
# absolute-form requests, they are directly given in the request. For
|
# absolute-form requests, they are directly given in the request. For
|
||||||
@ -301,11 +301,10 @@ class HttpLayer(base.Layer):
|
|||||||
if self.mode is HTTPMode.transparent:
|
if self.mode is HTTPMode.transparent:
|
||||||
# Setting request.host also updates the host header, which we want
|
# Setting request.host also updates the host header, which we want
|
||||||
# to preserve
|
# to preserve
|
||||||
host_header = f.request.headers.get("host", None)
|
host_header = f.request.host_header
|
||||||
f.request.host = self.__initial_server_conn.address.host
|
f.request.host = self.__initial_server_conn.address.host
|
||||||
f.request.port = self.__initial_server_conn.address.port
|
f.request.port = self.__initial_server_conn.address.port
|
||||||
if host_header:
|
f.request.host_header = host_header # set again as .host overwrites this.
|
||||||
f.request.headers["host"] = host_header
|
|
||||||
f.request.scheme = "https" if self.__initial_server_tls else "http"
|
f.request.scheme = "https" if self.__initial_server_tls else "http"
|
||||||
self.channel.ask("request", f)
|
self.channel.ask("request", f)
|
||||||
|
|
||||||
|
@ -97,7 +97,7 @@ class TestRequestCore:
|
|||||||
request.host = d
|
request.host = d
|
||||||
assert request.data.host == b"foo\xFF\x00bar"
|
assert request.data.host == b"foo\xFF\x00bar"
|
||||||
|
|
||||||
def test_host_header_update(self):
|
def test_host_update_also_updates_header(self):
|
||||||
request = treq()
|
request = treq()
|
||||||
assert "host" not in request.headers
|
assert "host" not in request.headers
|
||||||
request.host = "example.com"
|
request.host = "example.com"
|
||||||
@ -107,6 +107,51 @@ class TestRequestCore:
|
|||||||
request.host = "example.org"
|
request.host = "example.org"
|
||||||
assert request.headers["Host"] == "example.org"
|
assert request.headers["Host"] == "example.org"
|
||||||
|
|
||||||
|
def test_get_host_header(self):
|
||||||
|
no_hdr = treq()
|
||||||
|
assert no_hdr.host_header is None
|
||||||
|
|
||||||
|
h1 = treq(headers=(
|
||||||
|
(b"host", b"example.com"),
|
||||||
|
))
|
||||||
|
assert h1.host_header == "example.com"
|
||||||
|
|
||||||
|
h2 = treq(headers=(
|
||||||
|
(b":authority", b"example.org"),
|
||||||
|
))
|
||||||
|
assert h2.host_header == "example.org"
|
||||||
|
|
||||||
|
both_hdrs = treq(headers=(
|
||||||
|
(b"host", b"example.org"),
|
||||||
|
(b":authority", b"example.com"),
|
||||||
|
))
|
||||||
|
assert both_hdrs.host_header == "example.com"
|
||||||
|
|
||||||
|
def test_modify_host_header(self):
|
||||||
|
h1 = treq()
|
||||||
|
assert "host" not in h1.headers
|
||||||
|
assert ":authority" not in h1.headers
|
||||||
|
h1.host_header = "example.com"
|
||||||
|
assert "host" in h1.headers
|
||||||
|
assert ":authority" not in h1.headers
|
||||||
|
h1.host_header = None
|
||||||
|
assert "host" not in h1.headers
|
||||||
|
|
||||||
|
h2 = treq(http_version=b"HTTP/2.0")
|
||||||
|
h2.host_header = "example.org"
|
||||||
|
assert "host" not in h2.headers
|
||||||
|
assert ":authority" in h2.headers
|
||||||
|
del h2.host_header
|
||||||
|
assert ":authority" not in h2.headers
|
||||||
|
|
||||||
|
both_hdrs = treq(headers=(
|
||||||
|
(b":authority", b"example.com"),
|
||||||
|
(b"host", b"example.org"),
|
||||||
|
))
|
||||||
|
both_hdrs.host_header = "foo.example.com"
|
||||||
|
assert both_hdrs.headers["Host"] == "foo.example.com"
|
||||||
|
assert both_hdrs.headers[":authority"] == "foo.example.com"
|
||||||
|
|
||||||
|
|
||||||
class TestRequestUtils:
|
class TestRequestUtils:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user