HTTPRequest -> http.Request, add request.authority

This commit is contained in:
Maximilian Hils 2020-07-16 15:00:41 +02:00
parent 2dfcb537f2
commit 5af57cfa99
57 changed files with 1067 additions and 1036 deletions

View File

@ -4,7 +4,7 @@ from mitmproxy import ctx
def request(flow):
# Avoid an infinite loop by not replaying already replayed requests
if flow.request.is_replay:
if flow.is_replay == "request":
return
flow = flow.copy()
# Only interactive tools have a view. If we have one, add a duplicate entry

View File

@ -17,6 +17,7 @@ from mitmproxy import options
from mitmproxy.coretypes import basethread
from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1
from mitmproxy.net.http.url import hostport
from mitmproxy.utils import human
@ -46,7 +47,7 @@ class RequestReplayThread(basethread.BaseThread):
f.live = True
r = f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
authority_backup = r.authority
server = None
try:
f.response = None
@ -79,9 +80,9 @@ class RequestReplayThread(basethread.BaseThread):
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
r.authority = b""
else:
r.first_line_format = "absolute"
r.authority = hostport(r.scheme, r.host, r.port)
else:
server_address = (r.host, r.port)
server = connections.ServerConnection(server_address)
@ -91,7 +92,7 @@ class RequestReplayThread(basethread.BaseThread):
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
r.authority = ""
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
@ -101,9 +102,7 @@ class RequestReplayThread(basethread.BaseThread):
f.server_conn.close()
f.server_conn = server
f.response = http.HTTPResponse.wrap(
http1.read_response(server.rfile, r, body_size_limit=bsl)
)
f.response = http1.read_response(server.rfile, r, body_size_limit=bsl)
response_reply = self.channel.ask("response", f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
@ -115,7 +114,7 @@ class RequestReplayThread(basethread.BaseThread):
except Exception as e:
self.channel.tell("log", log.LogEntry(repr(e), "error"))
finally:
r.first_line_format = first_line_format_backup
r.authority = authority_backup
f.live = False
if server and server.connected():
server.finish()
@ -200,7 +199,7 @@ class ClientPlayback:
lst.append(hf)
# Prepare the flow for replay
hf.backup()
hf.request.is_replay = True
hf.is_replay = "request"
hf.response = None
hf.error = None
# https://github.com/mitmproxy/mitmproxy/issues/2197

View File

@ -127,7 +127,7 @@ class Dumper:
human.format_address(flow.client_conn.address)
)
)
elif flow.request.is_replay:
elif flow.is_replay == "request":
client = click.style("[replay]", fg="yellow", bold=True)
else:
client = ""
@ -166,7 +166,7 @@ class Dumper:
self.echo(line)
def _echo_response_line(self, flow):
if flow.response.is_replay:
if flow.is_replay == "response":
replay = click.style("[replay] ", fg="yellow", bold=True)
else:
replay = ""

View File

@ -37,7 +37,7 @@ class Intercept:
if self.filt:
should_intercept = all([
self.filt(f),
not f.request.is_replay,
not f.is_replay == "request",
])
if should_intercept and ctx.options.intercept_active:
f.intercept()

View File

@ -47,4 +47,4 @@ class MapRemote:
# this is a bit messy: setting .url also updates the host header,
# so we really only do that if the replacement affected the URL.
if url != new_url:
flow.request.url = new_url
flow.request.url = new_url # type: ignore

View File

@ -202,10 +202,10 @@ class ServerPlayback:
if rflow:
assert rflow.response
response = rflow.response.copy()
response.is_replay = True
if ctx.options.server_replay_refresh:
response.refresh()
f.response = response
f.is_replay = "response"
elif ctx.options.server_replay_kill_extra:
ctx.log.warn(
"server_playback: killed non-replay request {}".format(

View File

@ -1,5 +1,8 @@
import abc
import uuid
from typing import Type, TypeVar
T = TypeVar('T', bound='Serializable')
class Serializable(metaclass=abc.ABCMeta):
@ -9,7 +12,7 @@ class Serializable(metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
def from_state(cls, state):
def from_state(cls: Type[T], state) -> T:
"""
Create a new object from the given state.
"""
@ -29,7 +32,7 @@ class Serializable(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
def copy(self):
def copy(self: T) -> T:
state = self.get_state()
if isinstance(state, dict) and "id" in state:
state["id"] = str(uuid.uuid4())

View File

@ -77,6 +77,7 @@ class Flow(stateobject.StateObject):
self._backup: typing.Optional[Flow] = None
self.reply: typing.Optional[controller.Reply] = None
self.marked: bool = False
self.is_replay: typing.Optional[str] = None
self.metadata: typing.Dict[str, typing.Any] = dict()
_stateobject_attributes = dict(
@ -86,6 +87,7 @@ class Flow(stateobject.StateObject):
server_conn=connections.ServerConnection,
type=str,
intercepted=bool,
is_replay=str,
marked=bool,
metadata=typing.Dict[str, typing.Any],
)

View File

@ -1,142 +1,13 @@
import html
from typing import Optional
import time
from typing import Optional, Tuple
from mitmproxy import connections
from mitmproxy import flow
from mitmproxy import version
from mitmproxy.net import http
class HTTPRequest(http.Request):
"""
A mitmproxy HTTP request.
"""
# This is a very thin wrapper on top of :py:class:`mitmproxy.net.http.Request` and
# may be removed in the future.
def __init__(
self,
first_line_format,
method,
scheme,
host,
port,
path,
http_version,
headers,
content,
trailers=None,
timestamp_start=None,
timestamp_end=None,
is_replay=False,
):
http.Request.__init__(
self,
first_line_format,
method,
scheme,
host,
port,
path,
http_version,
headers,
content,
trailers,
timestamp_start,
timestamp_end,
)
# Is this request replayed?
self.is_replay = is_replay
self.stream = None
def get_state(self):
state = super().get_state()
state["is_replay"] = self.is_replay
return state
def set_state(self, state):
state = state.copy()
self.is_replay = state.pop("is_replay")
super().set_state(state)
@classmethod
def wrap(self, request):
"""
Wraps an existing :py:class:`mitmproxy.net.http.Request`.
"""
req = HTTPRequest(
first_line_format=request.data.first_line_format,
method=request.data.method,
scheme=request.data.scheme,
host=request.data.host,
port=request.data.port,
path=request.data.path,
http_version=request.data.http_version,
headers=request.data.headers,
content=request.data.content,
trailers=request.data.trailers,
timestamp_start=request.data.timestamp_start,
timestamp_end=request.data.timestamp_end,
)
return req
def __hash__(self):
return id(self)
class HTTPResponse(http.Response):
"""
A mitmproxy HTTP response.
"""
# This is a very thin wrapper on top of :py:class:`mitmproxy.net.http.Response` and
# may be removed in the future.
def __init__(
self,
http_version,
status_code,
reason,
headers,
content,
trailers=None,
timestamp_start=None,
timestamp_end=None,
is_replay=False
):
http.Response.__init__(
self,
http_version,
status_code,
reason,
headers,
content,
trailers,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
# Is this request replayed?
self.is_replay = is_replay
self.stream = None
@classmethod
def wrap(self, response):
"""
Wraps an existing :py:class:`mitmproxy.net.http.Response`.
"""
resp = HTTPResponse(
http_version=response.data.http_version,
status_code=response.data.status_code,
reason=response.data.reason,
headers=response.data.headers,
content=response.data.content,
trailers=response.data.trailers,
timestamp_start=response.data.timestamp_start,
timestamp_end=response.data.timestamp_end,
)
return resp
HTTPRequest = http.Request
HTTPResponse = http.Response
class HTTPFlow(flow.Flow):
@ -197,8 +68,7 @@ def make_error_response(
message: str = "",
headers: Optional[http.Headers] = None,
) -> HTTPResponse:
reason = http.status_codes.RESPONSES.get(status_code, "Unknown")
body = """
body: bytes = """
<html>
<head>
<title>{status_code} {reason}</title>
@ -210,7 +80,7 @@ def make_error_response(
</html>
""".strip().format(
status_code=status_code,
reason=reason,
reason=http.status_codes.RESPONSES.get(status_code, "Unknown"),
message=html.escape(message),
).encode("utf8", "replace")
@ -222,19 +92,23 @@ def make_error_response(
Content_Type="text/html"
)
return HTTPResponse(
b"HTTP/1.1",
status_code,
reason,
headers,
body,
)
return HTTPResponse.make(status_code, body, headers)
def make_connect_request(address):
def make_connect_request(address: Tuple[str, int]) -> HTTPRequest:
return HTTPRequest(
"authority", b"CONNECT", None, address[0], address[1], None, b"HTTP/1.1",
http.Headers(), b""
host=address[0],
port=address[1],
method=b"CONNECT",
scheme=b"",
authority=f"{address[0]}:{address[1]}".encode(),
path=b"",
http_version=b"HTTP/1.1",
headers=http.Headers(),
content=b"",
trailers=None,
timestamp_start=time.time(),
timestamp_end=time.time(),
)
@ -247,9 +121,11 @@ def make_connect_response(http_version):
b"Connection established",
http.Headers(),
b"",
None,
time.time(),
time.time(),
)
expect_continue_response = HTTPResponse(
b"HTTP/1.1", 100, b"Continue", http.Headers(), b""
)
def make_expect_continue_response():
return HTTPResponse.make(100)

View File

@ -179,6 +179,21 @@ def convert_7_8(data):
return data
def convert_8_9(data):
data["version"] = 9
data["request"].pop("first_line_format")
data["request"]["authority"] = b""
is_request_replay = data["request"].pop("is_replay", False)
is_response_replay = data["response"].pop("is_replay", False)
if is_request_replay: # pragma: no cover
data["is_replay"] = "request"
elif is_response_replay: # pragma: no cover
data["is_replay"] = "response"
else:
data["is_replay"] = None
return data
def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -234,6 +249,7 @@ converters = {
5: convert_5_6,
6: convert_6_7,
7: convert_7_8,
8: convert_8_9,
}
@ -251,8 +267,8 @@ def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, s
flow_data = converters[flow_version](flow_data)
else:
should_upgrade = (
isinstance(flow_version, int)
and flow_version > version.FLOW_FORMAT_VERSION
isinstance(flow_version, int)
and flow_version > version.FLOW_FORMAT_VERSION
)
raise ValueError(
"{} cannot read files with flow format version {}{}.".format(

View File

@ -106,8 +106,8 @@ def dumps(f: flow.Flow) -> bytes:
def _load_http_request(o: http_pb2.HTTPRequest) -> HTTPRequest:
d: dict = {}
_move_attrs(o, d, ['first_line_format', 'method', 'scheme', 'host', 'port', 'path', 'http_version', 'content',
'timestamp_start', 'timestamp_end', 'is_replay'])
_move_attrs(o, d, ['host', 'port', 'method', 'scheme', 'authority', 'path', 'http_version', 'content',
'timestamp_start', 'timestamp_end'])
if d['content'] is None:
d['content'] = b""
d["headers"] = []
@ -120,7 +120,7 @@ def _load_http_request(o: http_pb2.HTTPRequest) -> HTTPRequest:
def _load_http_response(o: http_pb2.HTTPResponse) -> HTTPResponse:
d: dict = {}
_move_attrs(o, d, ['http_version', 'status_code', 'reason',
'content', 'timestamp_start', 'timestamp_end', 'is_replay'])
'content', 'timestamp_start', 'timestamp_end'])
if d['content'] is None:
d['content'] = b""
d["headers"] = []

View File

@ -3,28 +3,37 @@ import re
# Allow underscore in host name
# Note: This could be a DNS label, a hostname, a FQDN, or an IP
from typing import AnyStr
_label_valid = re.compile(br"[A-Z\d\-_]{1,63}$", re.IGNORECASE)
def is_valid_host(host: bytes) -> bool:
def is_valid_host(host: AnyStr) -> bool:
"""
Checks if the passed bytes are a valid DNS hostname or an IPv4/IPv6 address.
"""
if isinstance(host, str):
try:
host_bytes = host.encode("idna")
except UnicodeError:
return False
else:
host_bytes = host
try:
host.decode("idna")
host_bytes.decode("idna")
except ValueError:
return False
# RFC1035: 255 bytes or less.
if len(host) > 255:
if len(host_bytes) > 255:
return False
if host and host[-1:] == b".":
host = host[:-1]
if host_bytes and host_bytes.endswith(b"."):
host_bytes = host_bytes[:-1]
# DNS hostname
if all(_label_valid.match(x) for x in host.split(b".")):
if all(_label_valid.match(x) for x in host_bytes.split(b".")):
return True
# IPv4/IPv6 address
try:
ipaddress.ip_address(host.decode('idna'))
ipaddress.ip_address(host_bytes.decode('idna'))
return True
except ValueError:
return False

View File

@ -11,8 +11,7 @@ import zlib
import brotli
import zstandard as zstd
from typing import Union, Optional, AnyStr # noqa
from typing import Union, Optional, AnyStr, overload # noqa
# We have a shared single-element cache for encoding and decoding.
# This is quite useful in practice, e.g.
@ -24,9 +23,24 @@ CachedDecode = collections.namedtuple(
_cache = CachedDecode(None, None, None, None)
@overload
def decode(encoded: None, encoding: str, errors: str = 'strict') -> None:
...
@overload
def decode(encoded: str, encoding: str, errors: str = 'strict') -> str:
...
@overload
def decode(encoded: bytes, encoding: str, errors: str = 'strict') -> Union[str, bytes]:
...
def decode(
encoded: Optional[bytes], encoding: str, errors: str='strict'
) -> Optional[AnyStr]:
encoded: Union[None, str, bytes], encoding: str, errors: str = 'strict'
) -> Union[None, str, bytes]:
"""
Decode the given input object
@ -41,10 +55,10 @@ def decode(
global _cache
cached = (
isinstance(encoded, bytes) and
_cache.encoded == encoded and
_cache.encoding == encoding and
_cache.errors == errors
isinstance(encoded, bytes) and
_cache.encoded == encoded and
_cache.encoding == encoding and
_cache.errors == errors
)
if cached:
return _cache.decoded
@ -52,7 +66,7 @@ def decode(
try:
decoded = custom_decode[encoding](encoded)
except KeyError:
decoded = codecs.decode(encoded, encoding, errors)
decoded = codecs.decode(encoded, encoding, errors) # type: ignore
if encoding in ("gzip", "deflate", "br", "zstd"):
_cache = CachedDecode(encoded, encoding, errors, decoded)
return decoded
@ -67,7 +81,22 @@ def decode(
))
def encode(decoded: Optional[str], encoding: str, errors: str='strict') -> Optional[AnyStr]:
@overload
def encode(decoded: None, encoding: str, errors: str = 'strict') -> None:
...
@overload
def encode(decoded: str, encoding: str, errors: str = 'strict') -> Union[str, bytes]:
...
@overload
def encode(decoded: bytes, encoding: str, errors: str = 'strict') -> bytes:
...
def encode(decoded: Union[None, str, bytes], encoding, errors='strict') -> Union[None, str, bytes]:
"""
Encode the given input object
@ -82,10 +111,10 @@ def encode(decoded: Optional[str], encoding: str, errors: str='strict') -> Optio
global _cache
cached = (
isinstance(decoded, bytes) and
_cache.decoded == decoded and
_cache.encoding == encoding and
_cache.errors == errors
isinstance(decoded, bytes) and
_cache.decoded == decoded and
_cache.encoding == encoding and
_cache.errors == errors
)
if cached:
return _cache.encoded
@ -93,7 +122,7 @@ def encode(decoded: Optional[str], encoding: str, errors: str='strict') -> Optio
try:
encoded = custom_encode[encoding](decoded)
except KeyError:
encoded = codecs.encode(decoded, encoding, errors)
encoded = codecs.encode(decoded, encoding, errors) # type: ignore
if encoding in ("gzip", "deflate", "br", "zstd"):
_cache = CachedDecode(encoded, encoding, errors, decoded)
return encoded
@ -150,7 +179,7 @@ def decode_zstd(content: bytes) -> bytes:
except zstd.ZstdError:
# If the zstd stream is streamed without a size header,
# try decoding with a 10MiB output buffer
return zstd_ctx.decompress(content, max_output_size=10 * 2**20)
return zstd_ctx.decompress(content, max_output_size=10 * 2 ** 20)
def encode_zstd(content: bytes) -> bytes:

View File

@ -1,7 +1,10 @@
import collections
from typing import Dict, Optional, Tuple
from mitmproxy.coretypes import multidict
from mitmproxy.utils import strutils
# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
@ -146,7 +149,7 @@ class Headers(multidict.MultiDict):
return super().items()
def parse_content_type(c):
def parse_content_type(c: str) -> Optional[Tuple[str, str, Dict[str, str]]]:
"""
A simple parser for content-type values. Returns a (type, subtype,
parameters) tuple, where type and subtype are strings, and parameters

View File

@ -45,31 +45,26 @@ def _assemble_request_line(request_data):
Args:
request_data (mitmproxy.net.http.request.RequestData)
"""
form = request_data.first_line_format
if form == "relative":
if request_data.method.upper() == b"CONNECT":
return b"%s %s %s" % (
request_data.method,
request_data.authority,
request_data.http_version
)
elif request_data.authority:
return b"%s %s://%s%s %s" % (
request_data.method,
request_data.scheme,
request_data.authority,
request_data.path,
request_data.http_version
)
else:
return b"%s %s %s" % (
request_data.method,
request_data.path,
request_data.http_version
)
elif form == "authority":
return b"%s %s:%d %s" % (
request_data.method,
request_data.host,
request_data.port,
request_data.http_version
)
elif form == "absolute":
return b"%s %s://%s:%d%s %s" % (
request_data.method,
request_data.scheme,
request_data.host,
request_data.port,
request_data.path,
request_data.http_version
)
else:
raise RuntimeError("Invalid request form")
def _assemble_request_headers(request_data):

View File

@ -1,15 +1,13 @@
import time
import sys
import re
import sys
import time
import typing
from mitmproxy import exceptions
from mitmproxy.net.http import headers
from mitmproxy.net.http import request
from mitmproxy.net.http import response
from mitmproxy.net.http import headers
from mitmproxy.net.http import url
from mitmproxy.net import check
from mitmproxy import exceptions
def get_header_tokens(headers, key):
@ -51,7 +49,7 @@ def read_request_head(rfile):
if hasattr(rfile, "reset_timestamps"):
rfile.reset_timestamps()
form, method, scheme, host, port, path, http_version = _read_request_line(rfile)
host, port, method, scheme, authority, path, http_version = _read_request_line(rfile)
headers = _read_headers(rfile)
if hasattr(rfile, "first_byte_timestamp"):
@ -59,7 +57,7 @@ def read_request_head(rfile):
timestamp_start = rfile.first_byte_timestamp
return request.Request(
form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
host, port, method, scheme, authority, path, http_version, headers, None, None, timestamp_start, None
)
@ -98,7 +96,7 @@ def read_response_head(rfile):
# more accurate timestamp_start
timestamp_start = rfile.first_byte_timestamp
return response.Response(http_version, status_code, message, headers, None, None, timestamp_start)
return response.Response(http_version, status_code, message, headers, None, None, timestamp_start, None)
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
@ -248,45 +246,32 @@ def _read_request_line(rfile):
raise exceptions.HttpReadDisconnect("Client disconnected")
try:
method, path, http_version = line.split()
method, target, http_version = line.split()
if path == b"*" or path.startswith(b"/"):
form = "relative"
scheme, host, port = None, None, None
if target == b"*" or target.startswith(b"/"):
scheme, authority, path = b"", b"", target
host, port = "", 0
elif method == b"CONNECT":
form = "authority"
host, port = _parse_authority_form(path)
scheme, path = None, None
scheme, authority, path = b"", target, b""
host, port = url.parse_authority(authority, check=True)
if not port:
raise ValueError
else:
form = "absolute"
scheme, host, port, path = url.parse(path)
scheme, rest = target.split(b"://", maxsplit=1)
authority, path_ = rest.split(b"/", maxsplit=1)
path = b"/" + path_
host, port = url.parse_authority(authority, check=True)
port = port or url.default_port(scheme)
if not port:
raise ValueError
# TODO: we can probably get rid of this check?
url.parse(target)
_check_http_version(http_version)
except ValueError:
raise exceptions.HttpSyntaxException("Bad HTTP request line: {}".format(line))
raise exceptions.HttpSyntaxException(f"Bad HTTP request line: {line}")
return form, method, scheme, host, port, path, http_version
def _parse_authority_form(hostport):
"""
Returns (host, port) if hostport is a valid authority-form host specification.
http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
Raises:
ValueError, if the input is malformed
"""
try:
host, port = hostport.rsplit(b":", 1)
if host.startswith(b"[") and host.endswith(b"]"):
host = host[1:-1]
port = int(port)
if not check.is_valid_host(host) or not check.is_valid_port(port):
raise ValueError()
except ValueError:
raise exceptions.HttpSyntaxException("Invalid host specification: {}".format(hostport))
return host, port
return host, port, method, scheme, authority, path, http_version
def _read_response_line(rfile):

View File

@ -1,50 +1,54 @@
import re
from typing import Optional # noqa
from dataclasses import dataclass, fields
from typing import Callable, Optional, Union, cast
from mitmproxy.utils import strutils
from mitmproxy.net.http import encoding
from mitmproxy.coretypes import serializable
from mitmproxy.net.http import headers as mheaders
from mitmproxy.net.http import encoding
from mitmproxy.net.http.headers import Headers, assemble_content_type, parse_content_type
from mitmproxy.utils import strutils, typecheck
@dataclass
class MessageData(serializable.Serializable):
headers: mheaders.Headers
content: bytes
http_version: bytes
headers: Headers
content: Optional[bytes]
trailers: Optional[Headers]
timestamp_start: float
timestamp_end: float
timestamp_end: Optional[float]
def __eq__(self, other):
if isinstance(other, MessageData):
return self.__dict__ == other.__dict__
return False
# noinspection PyUnreachableCode
if __debug__:
def __post_init__(self):
for field in fields(self):
val = getattr(self, field.name)
typecheck.check_option_type(field.name, val, field.type)
def set_state(self, state):
for k, v in state.items():
if k == "headers":
v = mheaders.Headers.from_state(v)
if k in ("headers", "trailers") and v is not None:
v = Headers.from_state(v)
setattr(self, k, v)
def get_state(self):
state = vars(self).copy()
state["headers"] = state["headers"].get_state()
if 'trailers' in state and state["trailers"] is not None:
if state["trailers"] is not None:
state["trailers"] = state["trailers"].get_state()
return state
@classmethod
def from_state(cls, state):
state["headers"] = mheaders.Headers.from_state(state["headers"])
state["headers"] = Headers.from_state(state["headers"])
if state["trailers"] is not None:
state["trailers"] = Headers.from_state(state["trailers"])
return cls(**state)
class Message(serializable.Serializable):
data: MessageData
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
return False
@classmethod
def from_state(cls, state):
return cls(**state)
def get_state(self):
return self.data.get_state()
@ -52,29 +56,48 @@ class Message(serializable.Serializable):
def set_state(self, state):
self.data.set_state(state)
@classmethod
def from_state(cls, state):
state["headers"] = mheaders.Headers.from_state(state["headers"])
if 'trailers' in state and state["trailers"] is not None:
state["trailers"] = mheaders.Headers.from_state(state["trailers"])
return cls(**state)
data: MessageData
stream: Union[Callable, bool] = False
@property
def headers(self):
def http_version(self) -> str:
"""
Message headers object
Version string, e.g. "HTTP/1.1"
"""
return self.data.http_version.decode("utf-8", "surrogateescape")
Returns:
mitmproxy.net.http.Headers
@http_version.setter
def http_version(self, http_version: Union[str, bytes]) -> None:
self.data.http_version = strutils.always_bytes(http_version, "utf-8", "surrogateescape")
@property
def is_http2(self) -> bool:
return self.data.http_version == b"HTTP/2.0"
@property
def headers(self) -> Headers:
"""
The HTTP headers.
"""
return self.data.headers
@headers.setter
def headers(self, h):
def headers(self, h: Headers) -> None:
self.data.headers = h
@property
def raw_content(self) -> bytes:
def trailers(self) -> Optional[Headers]:
"""
The HTTP trailers.
"""
return self.data.trailers
@trailers.setter
def trailers(self, h: Optional[Headers]) -> None:
self.data.trailers = h
@property
def raw_content(self) -> Optional[bytes]:
"""
The raw (potentially compressed) HTTP message body as bytes.
@ -83,10 +106,10 @@ class Message(serializable.Serializable):
return self.data.content
@raw_content.setter
def raw_content(self, content):
def raw_content(self, content: Optional[bytes]) -> None:
self.data.content = content
def get_content(self, strict: bool=True) -> Optional[bytes]:
def get_content(self, strict: bool = True) -> Optional[bytes]:
"""
The uncompressed HTTP message body as bytes.
@ -112,15 +135,14 @@ class Message(serializable.Serializable):
else:
return self.raw_content
def set_content(self, value):
def set_content(self, value: Optional[bytes]) -> None:
if value is None:
self.raw_content = None
return
if not isinstance(value, bytes):
raise TypeError(
"Message content must be bytes, not {}. "
f"Message content must be bytes, not {type(value).__name__}. "
"Please use .text if you want to assign a str."
.format(type(value).__name__)
)
ce = self.headers.get("content-encoding")
try:
@ -135,59 +157,34 @@ class Message(serializable.Serializable):
content = property(get_content, set_content)
@property
def trailers(self):
"""
Message trailers object
Returns:
mitmproxy.net.http.Headers
"""
return self.data.trailers
@trailers.setter
def trailers(self, h):
self.data.trailers = h
@property
def http_version(self):
"""
Version string, e.g. "HTTP/1.1"
"""
return self.data.http_version.decode("utf-8", "surrogateescape")
@http_version.setter
def http_version(self, http_version):
self.data.http_version = strutils.always_bytes(http_version, "utf-8", "surrogateescape")
@property
def timestamp_start(self):
def timestamp_start(self) -> float:
"""
First byte timestamp
"""
return self.data.timestamp_start
@timestamp_start.setter
def timestamp_start(self, timestamp_start):
def timestamp_start(self, timestamp_start: float) -> None:
self.data.timestamp_start = timestamp_start
@property
def timestamp_end(self):
def timestamp_end(self) -> Optional[float]:
"""
Last byte timestamp
"""
return self.data.timestamp_end
@timestamp_end.setter
def timestamp_end(self, timestamp_end):
def timestamp_end(self, timestamp_end: Optional[float]):
self.data.timestamp_end = timestamp_end
def _get_content_type_charset(self) -> Optional[str]:
ct = mheaders.parse_content_type(self.headers.get("content-type", ""))
ct = parse_content_type(self.headers.get("content-type", ""))
if ct:
return ct[2].get("charset")
return None
def _guess_encoding(self, content=b"") -> str:
def _guess_encoding(self, content: bytes = b"") -> str:
enc = self._get_content_type_charset()
if not enc:
if "json" in self.headers.get("content-type", ""):
@ -204,7 +201,7 @@ class Message(serializable.Serializable):
return enc
def get_text(self, strict: bool=True) -> Optional[str]:
def get_text(self, strict: bool = True) -> Optional[str]:
"""
The uncompressed and decoded HTTP message body as text.
@ -218,13 +215,13 @@ class Message(serializable.Serializable):
return None
enc = self._guess_encoding(content)
try:
return encoding.decode(content, enc)
return cast(str, encoding.decode(content, enc))
except ValueError:
if strict:
raise
return content.decode("utf8", "surrogateescape")
def set_text(self, text):
def set_text(self, text: Optional[str]) -> None:
if text is None:
self.content = None
return
@ -234,15 +231,15 @@ class Message(serializable.Serializable):
self.content = encoding.encode(text, enc)
except ValueError:
# Fall back to UTF-8 and update the content-type header.
ct = mheaders.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
ct = parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
ct[2]["charset"] = "utf-8"
self.headers["content-type"] = mheaders.assemble_content_type(*ct)
self.headers["content-type"] = assemble_content_type(*ct)
enc = "utf8"
self.content = text.encode(enc, "surrogateescape")
text = property(get_text, set_text)
def decode(self, strict=True):
def decode(self, strict: bool = True) -> None:
"""
Decodes body based on the current Content-Encoding header, then
removes the header. If there is no Content-Encoding header, no
@ -255,7 +252,7 @@ class Message(serializable.Serializable):
self.headers.pop("content-encoding", None)
self.content = decoded
def encode(self, e):
def encode(self, e: str) -> None:
"""
Encodes body with the encoding e, where e is "gzip", "deflate", "identity", "br", or "zstd".
Any existing content-encodings are overwritten,

View File

@ -1,67 +1,24 @@
import re
import urllib
import time
from typing import Optional, AnyStr, Dict, Iterable, Tuple, Union
import urllib.parse
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, Union
from mitmproxy.coretypes import multidict
from mitmproxy.utils import strutils
from mitmproxy.net.http import multipart
from mitmproxy.net.http import cookies
from mitmproxy.net.http import headers as nheaders
from mitmproxy.net.http import message
import mitmproxy.net.http.url
# This regex extracts & splits the host header into host and port.
# Handles the edge case of IPv6 addresses containing colons.
# https://bugzilla.mozilla.org/show_bug.cgi?id=45891
host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
from mitmproxy.coretypes import multidict
from mitmproxy.net.http import cookies, multipart
from mitmproxy.net.http import message
from mitmproxy.net.http.headers import Headers
from mitmproxy.utils.strutils import always_bytes, always_str
@dataclass
class RequestData(message.MessageData):
def __init__(
self,
first_line_format,
method,
scheme,
host,
port,
path,
http_version,
headers=(),
content=None,
trailers=None,
timestamp_start=None,
timestamp_end=None
):
if isinstance(method, str):
method = method.encode("ascii", "strict")
if isinstance(scheme, str):
scheme = scheme.encode("ascii", "strict")
if isinstance(host, str):
host = host.encode("idna", "strict")
if isinstance(path, str):
path = path.encode("ascii", "strict")
if isinstance(http_version, str):
http_version = http_version.encode("ascii", "strict")
if not isinstance(headers, nheaders.Headers):
headers = nheaders.Headers(headers)
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.first_line_format = first_line_format
self.method = method
self.scheme = scheme
self.host = host
self.port = port
self.path = path
self.http_version = http_version
self.headers = headers
self.content = content
self.trailers = trailers
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
host: str
port: int
method: bytes
scheme: bytes
authority: bytes
path: bytes
class Request(message.Message):
@ -70,19 +27,64 @@ class Request(message.Message):
"""
data: RequestData
def __init__(self, *args, **kwargs):
super().__init__()
self.data = RequestData(*args, **kwargs)
def __init__(
self,
host: str,
port: int,
method: bytes,
scheme: bytes,
authority: bytes,
path: bytes,
http_version: bytes,
headers: Union[Headers, Tuple[Tuple[bytes, bytes], ...]],
content: Optional[bytes],
trailers: Union[None, Headers, Tuple[Tuple[bytes, bytes], ...]],
timestamp_start: float,
timestamp_end: Optional[float],
):
# auto-convert invalid types to retain compatibility with older code.
if isinstance(host, bytes):
host = host.decode("idna", "strict")
if isinstance(method, str):
method = method.encode("ascii", "strict")
if isinstance(scheme, str):
scheme = scheme.encode("ascii", "strict")
if isinstance(authority, str):
authority = authority.encode("ascii", "strict")
if isinstance(path, str):
path = path.encode("ascii", "strict")
if isinstance(http_version, str):
http_version = http_version.encode("ascii", "strict")
def __repr__(self):
if isinstance(content, str):
raise ValueError(f"Content must be bytes, not {type(content).__name__}")
if not isinstance(headers, Headers):
headers = Headers(headers)
if trailers is not None and not isinstance(trailers, Headers):
trailers = Headers(trailers)
self.data = RequestData(
host=host,
port=port,
method=method,
scheme=scheme,
authority=authority,
path=path,
http_version=http_version,
headers=headers,
content=content,
trailers=trailers,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
def __repr__(self) -> str:
if self.host and self.port:
hostport = "{}:{}".format(self.host, self.port)
hostport = f"{self.host}:{self.port}"
else:
hostport = ""
path = self.path or ""
return "Request({} {}{})".format(
self.method, hostport, path
)
return f"Request({self.method} {hostport}{path})"
@classmethod
def make(
@ -90,117 +92,134 @@ class Request(message.Message):
method: str,
url: str,
content: Union[bytes, str] = "",
headers: Union[Dict[str, AnyStr], Iterable[Tuple[bytes, bytes]]] = ()
):
headers: Union[Headers, Dict[Union[str, bytes], Union[str, bytes]], Iterable[Tuple[bytes, bytes]]] = ()
) -> "Request":
"""
Simplified API for creating request objects.
"""
req = cls(
"absolute",
method,
"",
"",
"",
"",
"HTTP/1.1",
(),
b""
)
req.url = url
req.timestamp_start = time.time()
# Headers can be list or dict, we differentiate here.
if isinstance(headers, dict):
req.headers = nheaders.Headers(**headers)
if isinstance(headers, Headers):
pass
elif isinstance(headers, dict):
headers = Headers(
(always_bytes(k, "utf-8", "surrogateescape"),
always_bytes(v, "utf-8", "surrogateescape"))
for k, v in headers.items()
)
elif isinstance(headers, Iterable):
req.headers = nheaders.Headers(headers)
headers = Headers(headers)
else:
raise TypeError("Expected headers to be an iterable or dict, but is {}.".format(
type(headers).__name__
))
req = cls(
"",
0,
method.encode("utf-8", "surrogateescape"),
b"",
b"",
b"",
b"HTTP/1.1",
headers,
b"",
None,
time.time(),
time.time(),
)
req.url = url
# Assign this manually to update the content-length header.
if isinstance(content, bytes):
req.content = content
elif isinstance(content, str):
req.text = content
else:
raise TypeError("Expected content to be str or bytes, but is {}.".format(
type(content).__name__
))
raise TypeError(f"Expected content to be str or bytes, but is {type(content).__name__}.")
return req
@property
def first_line_format(self):
def first_line_format(self) -> str:
"""
HTTP request form as defined in `RFC7230 <https://tools.ietf.org/html/rfc7230#section-5.3>`_.
origin-form and asterisk-form are subsumed as "relative".
"""
return self.data.first_line_format
@first_line_format.setter
def first_line_format(self, first_line_format):
self.data.first_line_format = first_line_format
if self.method == "CONNECT":
return "authority"
elif self.authority:
return "absolute"
else:
return "relative"
@property
def method(self):
def method(self) -> str:
"""
HTTP request method, e.g. "GET".
"""
return self.data.method.decode("utf-8", "surrogateescape").upper()
@method.setter
def method(self, method):
self.data.method = strutils.always_bytes(method, "utf-8", "surrogateescape")
def method(self, val: Union[str, bytes]) -> None:
self.data.method = always_bytes(val, "utf-8", "surrogateescape")
@property
def scheme(self):
def scheme(self) -> str:
"""
HTTP request scheme, which should be "http" or "https".
"""
if self.data.scheme is None:
return None
return self.data.scheme.decode("utf-8", "surrogateescape")
@scheme.setter
def scheme(self, scheme):
self.data.scheme = strutils.always_bytes(scheme, "utf-8", "surrogateescape")
def scheme(self, val: Union[str, bytes]) -> None:
self.data.scheme = always_bytes(val, "utf-8", "surrogateescape")
@property
def host(self):
def authority(self) -> str:
"""
HTTP request authority.
For HTTP/1, this is the authority portion of the request target
(in either absolute-form or authority-form)
For HTTP/2, this is the :authority pseudo header.
"""
try:
return self.data.authority.decode("idna")
except UnicodeError:
return self.data.authority.decode("utf8", "surrogateescape")
@authority.setter
def authority(self, val: Union[str, bytes]) -> None:
if isinstance(val, str):
try:
val = val.encode("idna", "strict")
except UnicodeError:
val = val.encode("utf8", "surrogateescape") # type: ignore
self.data.authority = val
@property
def host(self) -> str:
"""
Target host. This may be parsed from the raw request
(e.g. from a ``GET http://example.com/ HTTP/1.1`` request line)
or inferred from the proxy mode (e.g. an IP in transparent mode).
Setting the host attribute also updates the host header, if present.
Setting the host attribute also updates the host header and authority information, if present.
"""
if not self.data.host:
return self.data.host
try:
return self.data.host.decode("idna")
except UnicodeError:
return self.data.host.decode("utf8", "surrogateescape")
return self.data.host
@host.setter
def host(self, host):
if isinstance(host, str):
try:
# There's no non-strict mode for IDNA encoding.
# We don't want this operation to fail though, so we try
# utf8 as a last resort.
host = host.encode("idna", "strict")
except UnicodeError:
host = host.encode("utf8", "surrogateescape")
self.data.host = host
def host(self, val: Union[str, bytes]) -> None:
self.data.host = always_str(val, "idna", "strict")
# Update host header
if self.host_header is not None:
self.host_header = host
if "Host" in self.data.headers:
self.data.headers["Host"] = val
# Update authority
if self.data.authority:
self.authority = mitmproxy.net.http.url.hostport(self.scheme, self.host, self.port)
@property
def host_header(self) -> Optional[str]:
@ -208,111 +227,92 @@ class Request(message.Message):
The request's host/authority header.
This property maps to either ``request.headers["Host"]`` or
``request.headers[":authority"]``, depending on whether it's HTTP/1.x or HTTP/2.0.
``request.authority``, depending on whether it's HTTP/1.x or HTTP/2.0.
"""
if ":authority" in self.headers:
return self.headers[":authority"]
if "Host" in self.headers:
return self.headers["Host"]
return None
if self.is_http2:
return self.authority or self.data.headers.get("Host", None)
else:
return self.data.headers.get("Host", None)
@host_header.setter
def host_header(self, val: Optional[str]) -> None:
def host_header(self, val: Union[None, str, bytes]) -> None:
if val is None:
if self.is_http2:
self.data.authority = b""
self.headers.pop("Host", None)
self.headers.pop(":authority", None)
elif self.host_header is not None:
# Update any existing headers.
if ":authority" in self.headers:
self.headers[":authority"] = val
if "Host" in self.headers:
self.headers["Host"] = val
else:
# Only add the correct new header.
if self.http_version.upper().startswith("HTTP/2"):
self.headers[":authority"] = val
else:
if self.is_http2:
self.authority = val # type: ignore
if not self.is_http2 or "Host" in self.headers:
# For h2, we only overwrite, but not create, as :authority is the h2 host header.
self.headers["Host"] = val
@host_header.deleter
def host_header(self):
self.host_header = None
@property
def port(self):
def port(self) -> int:
"""
Target port
"""
return self.data.port
@port.setter
def port(self, port):
def port(self, port: int) -> None:
self.data.port = port
@property
def path(self):
def path(self) -> str:
"""
HTTP request path, e.g. "/index.html".
Guaranteed to start with a slash, except for OPTIONS requests, which may just be "*".
Usually starts with a slash, except for OPTIONS requests, which may just be "*".
"""
if self.data.path is None:
return None
else:
return self.data.path.decode("utf-8", "surrogateescape")
return self.data.path.decode("utf-8", "surrogateescape")
@path.setter
def path(self, path):
self.data.path = strutils.always_bytes(path, "utf-8", "surrogateescape")
def path(self, val: Union[str, bytes]) -> None:
self.data.path = always_bytes(val, "utf-8", "surrogateescape")
@property
def url(self):
def url(self) -> str:
"""
The URL string, constructed from the request's URL components
The URL string, constructed from the request's URL components.
"""
if self.first_line_format == "authority":
return "%s:%d" % (self.host, self.port)
return f"{self.host}:{self.port}"
return mitmproxy.net.http.url.unparse(self.scheme, self.host, self.port, self.path)
@url.setter
def url(self, url):
self.scheme, self.host, self.port, self.path = mitmproxy.net.http.url.parse(url)
def _parse_host_header(self):
"""Extract the host and port from Host header"""
host = self.host_header
if not host:
return None, None
port = None
m = host_header_re.match(host)
if m:
host = m.group("host").strip("[]")
if m.group("port"):
port = int(m.group("port"))
return host, port
def url(self, val: Union[str, bytes]) -> None:
val = always_str(val, "utf-8", "surrogateescape")
self.scheme, self.host, self.port, self.path = mitmproxy.net.http.url.parse(val)
@property
def pretty_host(self):
def pretty_host(self) -> str:
"""
Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source.
Similar to :py:attr:`host`, but using the host/:authority header as an additional (preferred) data source.
This is useful in transparent mode where :py:attr:`host` is only an IP address,
but may not reflect the actual destination as the Host header could be spoofed.
"""
host, port = self._parse_host_header()
if not host:
authority = self.host_header
if authority:
return mitmproxy.net.http.url.parse_authority(authority, check=False)[0]
else:
return self.host
if not port:
port = 443 if self.scheme == 'https' else 80
# Prefer the original address if host header has an unexpected form
return host if port == self.port else self.host
@property
def pretty_url(self):
def pretty_url(self) -> str:
"""
Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`.
"""
if self.first_line_format == "authority":
return "%s:%d" % (self.pretty_host, self.port)
return mitmproxy.net.http.url.unparse(self.scheme, self.pretty_host, self.port, self.path)
return self.authority
host_header = self.host_header
if not host_header:
return self.url
pretty_host, pretty_port = mitmproxy.net.http.url.parse_authority(host_header, check=False)
pretty_port = pretty_port or mitmproxy.net.http.url.default_port(self.scheme) or 443
return mitmproxy.net.http.url.unparse(self.scheme, pretty_host, pretty_port, self.path)
def _get_query(self):
query = urllib.parse.urlparse(self.url).query
@ -379,7 +379,7 @@ class Request(message.Message):
_, _, _, params, query, fragment = urllib.parse.urlparse(self.url)
self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
def anticache(self):
def anticache(self) -> None:
"""
Modifies this request to remove headers that might produce a cached
response. That is, we remove ETags and If-Modified-Since headers.
@ -391,14 +391,14 @@ class Request(message.Message):
for i in delheaders:
self.headers.pop(i, None)
def anticomp(self):
def anticomp(self) -> None:
"""
Modifies this request to remove headers that will compress the
resource's data.
"""
self.headers["accept-encoding"] = "identity"
def constrain_encoding(self):
def constrain_encoding(self) -> None:
"""
Limits the permissible Accept-Encoding values, based on what we can
decode appropriately.

View File

@ -1,50 +1,25 @@
import time
from email.utils import parsedate_tz, formatdate, mktime_tz
from mitmproxy.utils import human
from mitmproxy.coretypes import multidict
from mitmproxy.net.http import cookies
from mitmproxy.net.http import headers as nheaders
from mitmproxy.net.http import message
from mitmproxy.net.http import status_codes
from mitmproxy.utils import strutils
from typing import AnyStr
from dataclasses import dataclass
from email.utils import formatdate, mktime_tz, parsedate_tz
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Tuple
from typing import Union
from mitmproxy.coretypes import multidict
from mitmproxy.net.http import cookies, message
from mitmproxy.net.http import status_codes
from mitmproxy.net.http.headers import Headers
from mitmproxy.utils import human
from mitmproxy.utils import strutils
from mitmproxy.utils.strutils import always_bytes
@dataclass
class ResponseData(message.MessageData):
def __init__(
self,
http_version,
status_code,
reason=None,
headers=(),
content=None,
trailers=None,
timestamp_start=None,
timestamp_end=None
):
if isinstance(http_version, str):
http_version = http_version.encode("ascii", "strict")
if isinstance(reason, str):
reason = reason.encode("ascii", "strict")
if not isinstance(headers, nheaders.Headers):
headers = nheaders.Headers(headers)
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.http_version = http_version
self.status_code = status_code
self.reason = reason
self.headers = headers
self.content = content
self.trailers = trailers
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
status_code: int
reason: bytes
class Response(message.Message):
@ -53,91 +28,119 @@ class Response(message.Message):
"""
data: ResponseData
def __init__(self, *args, **kwargs):
super().__init__()
self.data = ResponseData(*args, **kwargs)
def __init__(
self,
http_version: bytes,
status_code: int,
reason: bytes,
headers: Union[Headers, Tuple[Tuple[bytes, bytes], ...]],
content: Optional[bytes],
trailers: Union[None, Headers, Tuple[Tuple[bytes, bytes], ...]],
timestamp_start: float,
timestamp_end: Optional[float],
):
# auto-convert invalid types to retain compatibility with older code.
if isinstance(http_version, str):
http_version = http_version.encode("ascii", "strict")
if isinstance(reason, str):
reason = reason.encode("ascii", "strict")
def __repr__(self):
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if not isinstance(headers, Headers):
headers = Headers(headers)
if trailers is not None and not isinstance(trailers, Headers):
trailers = Headers(trailers)
self.data = ResponseData(
http_version=http_version,
status_code=status_code,
reason=reason,
headers=headers,
content=content,
trailers=trailers,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
def __repr__(self) -> str:
if self.raw_content:
details = "{}, {}".format(
self.headers.get("content-type", "unknown content type"),
human.pretty_size(len(self.raw_content))
)
ct = self.headers.get("content-type", "unknown content type")
size = human.pretty_size(len(self.raw_content))
details = f"{ct}, {size}"
else:
details = "no content"
return "Response({status_code} {reason}, {details})".format(
status_code=self.status_code,
reason=self.reason,
details=details
)
return f"Response({self.status_code}, {details})"
@classmethod
def make(
cls,
status_code: int=200,
content: Union[bytes, str]=b"",
headers: Union[Dict[str, AnyStr], Iterable[Tuple[bytes, bytes]]]=()
):
status_code: int = 200,
content: Union[bytes, str] = b"",
headers: Union[Headers, Dict[Union[str, bytes], Union[str, bytes]], Iterable[Tuple[bytes, bytes]]] = ()
) -> "Response":
"""
Simplified API for creating response objects.
"""
resp = cls(
b"HTTP/1.1",
status_code,
status_codes.RESPONSES.get(status_code, "").encode(),
(),
None
)
# Headers can be list or dict, we differentiate here.
if isinstance(headers, dict):
resp.headers = nheaders.Headers(**headers)
if isinstance(headers, Headers):
headers = headers
elif isinstance(headers, dict):
headers = Headers(
(always_bytes(k, "utf-8", "surrogateescape"),
always_bytes(v, "utf-8", "surrogateescape"))
for k, v in headers.items()
)
elif isinstance(headers, Iterable):
resp.headers = nheaders.Headers(headers)
headers = Headers(headers)
else:
raise TypeError("Expected headers to be an iterable or dict, but is {}.".format(
type(headers).__name__
))
resp = cls(
b"HTTP/1.1",
status_code,
status_codes.RESPONSES.get(status_code, "").encode(),
headers,
None,
None,
time.time(),
time.time(),
)
# Assign this manually to update the content-length header.
if isinstance(content, bytes):
resp.content = content
elif isinstance(content, str):
resp.text = content
else:
raise TypeError("Expected content to be str or bytes, but is {}.".format(
type(content).__name__
))
raise TypeError(f"Expected content to be str or bytes, but is {type(content).__name__}.")
return resp
@property
def status_code(self):
def status_code(self) -> int:
"""
HTTP Status Code, e.g. ``200``.
"""
return self.data.status_code
@status_code.setter
def status_code(self, status_code):
def status_code(self, status_code: int) -> None:
self.data.status_code = status_code
@property
def reason(self):
def reason(self) -> str:
"""
HTTP Reason Phrase, e.g. "Not Found".
HTTP2 responses do not contain a reason phrase and self.data.reason will be :py:obj:`None`.
When :py:obj:`None` return an empty reason phrase so that functions expecting a string work properly.
HTTP/2 responses do not contain a reason phrase, an empty string will be returned instead.
"""
# Encoding: http://stackoverflow.com/a/16674906/934719
if self.data.reason is not None:
return self.data.reason.decode("ISO-8859-1", "surrogateescape")
else:
return ""
return self.data.reason.decode("ISO-8859-1")
@reason.setter
def reason(self, reason):
self.data.reason = strutils.always_bytes(reason, "ISO-8859-1", "surrogateescape")
def reason(self, reason: Union[str, bytes]) -> None:
self.data.reason = strutils.always_bytes(reason, "ISO-8859-1")
def _get_cookies(self):
h = self.headers.get_all("set-cookie")

View File

@ -1,8 +1,17 @@
import re
import urllib.parse
from typing import AnyStr, Optional
from typing import Sequence
from typing import Tuple
from mitmproxy.net import check
# This regex extracts & splits the host header into host and port.
# Handles the edge case of IPv6 addresses containing colons.
# https://bugzilla.mozilla.org/show_bug.cgi?id=45891
from mitmproxy.net.check import is_valid_host, is_valid_port
from mitmproxy.utils.strutils import always_str
_authority_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
def parse(url):
@ -21,6 +30,8 @@ def parse(url):
Raises:
ValueError, if the URL is not properly formatted.
"""
# FIXME: We shouldn't rely on urllib here.
# Size of Ascii character after encoding is 1 byte which is same as its size
# But non-Ascii character's size after encoding will be more than its size
def ascii_check(l):
@ -61,7 +72,7 @@ def parse(url):
return parsed.scheme, host, port, full_path
def unparse(scheme, host, port, path=""):
def unparse(scheme: str, host: str, port: int, path: str = "") -> str:
"""
Returns a URL string, constructed from the specified components.
@ -70,10 +81,11 @@ def unparse(scheme, host, port, path=""):
"""
if path == "*":
path = ""
return "%s://%s%s" % (scheme, hostport(scheme, host, port), path)
authority = hostport(scheme, host, port)
return f"{scheme}://{authority}{path}"
def encode(s: Sequence[Tuple[str, str]], similar_to: str=None) -> str:
def encode(s: Sequence[Tuple[str, str]], similar_to: str = None) -> str:
"""
Takes a list of (key, value) tuples and returns a urlencoded string.
If similar_to is passed, the output is formatted similar to the provided urlencoded string.
@ -100,7 +112,7 @@ def decode(s):
return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape')
def quote(b: str, safe: str="/") -> str:
def quote(b: str, safe: str = "/") -> str:
"""
Returns:
An ascii-encodable str.
@ -118,14 +130,59 @@ def unquote(s: str) -> str:
return urllib.parse.unquote(s, errors="surrogateescape")
def hostport(scheme, host, port):
def hostport(scheme: AnyStr, host: AnyStr, port: int) -> AnyStr:
"""
Returns the host component, with a port specifcation if needed.
Returns the host component, with a port specification if needed.
"""
if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]:
if default_port(scheme) == port:
return host
else:
if isinstance(host, bytes):
return b"%s:%d" % (host, port)
else:
return "%s:%d" % (host, port)
def default_port(scheme: AnyStr) -> Optional[int]:
return {
"http": 80,
b"http": 80,
"https": 443,
b"https": 443,
}.get(scheme, None)
def parse_authority(authority: AnyStr, check: bool) -> Tuple[str, Optional[int]]:
"""Extract the host and port from host header/authority information
Raises:
ValueError, if check is True and the authority information is malformed.
"""
try:
if isinstance(authority, bytes):
authority_str = authority.decode("idna")
else:
authority_str = authority
m = _authority_re.match(authority_str)
if not m:
raise ValueError
host = m.group("host")
if host.startswith("[") and host.endswith("]"):
host = host[1:-1]
if not is_valid_host(host):
raise ValueError
if m.group("port"):
port = int(m.group("port"))
if not is_valid_port(port):
raise ValueError
return host, port
else:
return host, None
except ValueError:
if check:
raise
else:
return always_str(authority, "utf-8", "surrogateescape"), None

View File

@ -148,12 +148,14 @@ MODE_REQUEST_FORMS = {
def validate_request_form(mode, request):
if request.first_line_format == "absolute" and request.scheme != "http":
if request.first_line_format == "absolute" and request.scheme not in ("http", "https"):
raise exceptions.HttpException(
"Invalid request scheme: %s" % request.scheme
)
allowed_request_forms = MODE_REQUEST_FORMS[mode]
if request.first_line_format not in allowed_request_forms:
if request.is_http2 and mode is HTTPMode.transparent and request.first_line_format == "absolute":
return # dirty hack: h2 may have authority info. will be fixed properly with sans-io.
if mode == HTTPMode.transparent:
err_message = textwrap.dedent((
"""
@ -252,7 +254,7 @@ class HttpLayer(base.Layer):
def _process_flow(self, f):
try:
try:
request = self.read_request_headers(f)
request: http.HTTPRequest = self.read_request_headers(f)
except exceptions.HttpReadDisconnect:
# don't throw an error for disconnects that happen
# before/between requests.
@ -287,7 +289,7 @@ class HttpLayer(base.Layer):
if request.headers.get("expect", "").lower() == "100-continue":
# TODO: We may have to use send_response_headers for HTTP2
# here.
self.send_response(http.expect_continue_response)
self.send_response(http.make_expect_continue_response())
request.headers.pop("expect")
if f.request.stream:
@ -318,7 +320,7 @@ class HttpLayer(base.Layer):
# set first line format to relative in regular mode,
# see https://github.com/mitmproxy/mitmproxy/issues/1759
if self.mode is HTTPMode.regular and request.first_line_format == "absolute":
request.first_line_format = "relative"
request.authority = ""
# update host header in reverse proxy mode
if self.config.options.mode.startswith("reverse:") and not self.config.options.keep_host_header:
@ -332,11 +334,9 @@ class HttpLayer(base.Layer):
if self.mode is HTTPMode.transparent:
# Setting request.host also updates the host header, which we want
# to preserve
host_header = f.request.host_header
f.request.host = self.__initial_server_address[0]
f.request.port = self.__initial_server_address[1]
f.request.host_header = host_header # set again as .host overwrites this.
f.request.scheme = "https" if self.__initial_server_tls else "http"
f.request.data.host = self.__initial_server_address[0]
f.request.data.port = self.__initial_server_address[1]
f.request.data.scheme = b"https" if self.__initial_server_tls else b"http"
self.channel.ask("request", f)
try:

View File

@ -1,6 +1,5 @@
from mitmproxy import http
from mitmproxy.proxy.protocol import http as httpbase
from mitmproxy.net.http import http1
from mitmproxy.proxy.protocol import http as httpbase
from mitmproxy.utils import human
@ -11,9 +10,7 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.mode = mode
def read_request_headers(self, flow):
return http.HTTPRequest.wrap(
http1.read_request_head(self.client_conn.rfile)
)
return http1.read_request_head(self.client_conn.rfile)
def read_request_body(self, request):
expected_size = http1.expected_http_body_size(request)
@ -50,8 +47,7 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.server_conn.wfile.flush()
def read_response_headers(self):
resp = http1.read_response_head(self.server_conn.rfile)
return http.HTTPResponse.wrap(resp)
return http1.read_response_head(self.server_conn.rfile)
def read_response_body(self, request, response):
expected_size = http1.expected_http_body_size(request, response)

View File

@ -16,7 +16,7 @@ from mitmproxy.proxy.protocol import http as httpbase
import mitmproxy.net.http
from mitmproxy.net import tcp
from mitmproxy.coretypes import basethread
from mitmproxy.net.http import http2, headers
from mitmproxy.net.http import http2, headers, url
from mitmproxy.utils import human
@ -506,19 +506,28 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
if self.pushed:
flow.metadata['h2-pushed-stream'] = True
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_message.headers)
# pseudo header must be present, see https://http2.github.io/http2-spec/#rfc.section.8.1.2.3
authority = self.request_message.headers.pop(':authority', "")
method = self.request_message.headers.pop(':method')
scheme = self.request_message.headers.pop(':scheme')
path = self.request_message.headers.pop(':path')
host, port = url.parse_authority(authority, check=True)
port = port or url.default_port(scheme) or 0
return http.HTTPRequest(
first_line_format,
method,
scheme,
host,
port,
path,
method.encode(),
scheme.encode(),
authority.encode(),
path.encode(),
b"HTTP/2.0",
self.request_message.headers,
None,
timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end,
None,
self.timestamp_start,
self.timestamp_end,
)
@detect_zombie_stream
@ -569,6 +578,8 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
headers = request.headers.copy()
if request.authority:
headers.insert(0, ":authority", request.authority)
headers.insert(0, ":path", request.path)
headers.insert(0, ":method", request.method)
headers.insert(0, ":scheme", request.scheme)
@ -640,6 +651,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
reason=b'',
headers=headers,
content=None,
trailers=None,
timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end,
)

View File

@ -40,25 +40,27 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
server_conn = tserver_conn()
if handshake_flow is True:
req = http.HTTPRequest(
"relative",
"GET",
"http",
"example.com",
80,
"/ws",
"HTTP/1.1",
b"GET",
b"http",
b"example.com",
b"/ws",
b"HTTP/1.1",
headers=net_http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
),
content=b'',
trailers=None,
timestamp_start=946681200,
timestamp_end=946681201,
content=b''
)
resp = http.HTTPResponse(
"HTTP/1.1",
b"HTTP/1.1",
101,
reason=net_http.status_codes.RESPONSES.get(101),
headers=net_http.Headers(
@ -66,9 +68,10 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
upgrade='websocket',
sec_websocket_accept=b'',
),
content=b'',
trailers=None,
timestamp_start=946681202,
timestamp_end=946681203,
content=b'',
)
handshake_flow = http.HTTPFlow(client_conn, server_conn)
handshake_flow.request = req
@ -114,11 +117,6 @@ def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None):
if err is True:
err = terr()
if req:
req = http.HTTPRequest.wrap(req)
if resp:
resp = http.HTTPResponse.wrap(resp)
f = http.HTTPFlow(client_conn, server_conn)
f.request = req
f.response = resp

View File

@ -12,21 +12,22 @@ def treader(bytes):
return tcp.Reader(fp)
def treq(**kwargs):
def treq(**kwargs) -> http.Request:
"""
Returns:
mitmproxy.net.http.Request
"""
default = dict(
first_line_format="relative",
host="address",
port=22,
method=b"GET",
scheme=b"http",
host=b"address",
port=22,
authority=b"",
path=b"/path",
http_version=b"HTTP/1.1",
headers=http.Headers(((b"header", b"qvalue"), (b"content-length", b"7"))),
content=b"content",
trailers=None,
timestamp_start=946681200,
timestamp_end=946681201,
)
@ -34,7 +35,7 @@ def treq(**kwargs):
return http.Request(**default)
def tresp(**kwargs):
def tresp(**kwargs) -> http.Response:
"""
Returns:
mitmproxy.net.http.Response
@ -45,6 +46,7 @@ def tresp(**kwargs):
reason=b"OK",
headers=http.Headers(((b"header-response", b"svalue"), (b"content-length", b"7"))),
content=b"message",
trailers=None,
timestamp_start=946681202,
timestamp_end=946681203,
)

View File

@ -381,6 +381,7 @@ def format_http_flow_list(
render_mode: RenderMode,
focused: bool,
marked: bool,
is_replay: bool,
request_method: str,
request_scheme: str,
request_host: str,
@ -389,13 +390,11 @@ def format_http_flow_list(
request_http_version: str,
request_timestamp: float,
request_is_push_promise: bool,
request_is_replay: bool,
intercepted: bool,
response_code: typing.Optional[int],
response_reason: typing.Optional[str],
response_content_length: typing.Optional[int],
response_content_type: typing.Optional[str],
response_is_replay: bool,
duration: typing.Optional[float],
error_message: typing.Optional[str],
) -> urwid.Widget:
@ -433,7 +432,7 @@ def format_http_flow_list(
else:
req.append(truncated_plain(request_url, url_style))
req.append(format_right_indicators(replay=request_is_replay or response_is_replay, marked=marked))
req.append(format_right_indicators(replay=is_replay, marked=marked))
resp = [
("fixed", preamble_len, urwid.Text(""))
@ -446,8 +445,6 @@ def format_http_flow_list(
status_style = style or HTTP_RESPONSE_CODE_STYLE.get(response_code // 100, "code_other")
resp.append(fcol(SYMBOL_RETURN, status_style))
if response_is_replay:
resp.append(fcol(SYMBOL_REPLAY, "replay"))
resp.append(fcol(str(response_code), status_style))
if response_reason and render_mode is RenderMode.DETAILVIEW:
resp.append(fcol(response_reason, status_style))
@ -485,6 +482,7 @@ def format_http_flow_table(
render_mode: RenderMode,
focused: bool,
marked: bool,
is_replay: typing.Optional[str],
request_method: str,
request_scheme: str,
request_host: str,
@ -493,13 +491,11 @@ def format_http_flow_table(
request_http_version: str,
request_timestamp: float,
request_is_push_promise: bool,
request_is_replay: bool,
intercepted: bool,
response_code: typing.Optional[int],
response_reason: typing.Optional[str],
response_content_length: typing.Optional[int],
response_content_type: typing.Optional[str],
response_is_replay: bool,
duration: typing.Optional[float],
error_message: typing.Optional[str],
) -> urwid.Widget:
@ -579,7 +575,7 @@ def format_http_flow_table(
items.append(("fixed", 5, urwid.Text("")))
items.append(format_right_indicators(
replay=request_is_replay or response_is_replay,
replay=bool(is_replay),
marked=marked
))
return urwid.Columns(items, dividechars=1, min_width=15)
@ -689,10 +685,9 @@ def format_flow(
response_content_length = len(f.response.raw_content)
else:
response_content_length = None
response_code = f.response.status_code
response_reason = f.response.reason
response_code: typing.Optional[int] = f.response.status_code
response_reason: typing.Optional[str] = f.response.reason
response_content_type = f.response.headers.get("content-type")
response_is_replay = f.response.is_replay
if f.response.timestamp_end:
duration = max([f.response.timestamp_end - f.request.timestamp_start, 0])
else:
@ -702,7 +697,6 @@ def format_flow(
response_code = None
response_reason = None
response_content_type = None
response_is_replay = False
duration = None
if render_mode in (RenderMode.LIST, RenderMode.DETAILVIEW):
@ -713,6 +707,7 @@ def format_flow(
render_mode=render_mode,
focused=focused,
marked=f.marked,
is_replay=f.is_replay,
request_method=f.request.method,
request_scheme=f.request.scheme,
request_host=f.request.pretty_host if hostheader else f.request.host,
@ -721,13 +716,11 @@ def format_flow(
request_http_version=f.request.http_version,
request_timestamp=f.request.timestamp_start,
request_is_push_promise='h2-pushed-stream' in f.metadata,
request_is_replay=f.request.is_replay,
intercepted=intercepted,
response_code=response_code,
response_reason=response_reason,
response_content_length=response_content_length,
response_content_type=response_content_type,
response_is_replay=response_is_replay,
duration=duration,
error_message=error_message,
)

View File

@ -33,6 +33,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
f = {
"id": flow.id,
"intercepted": flow.intercepted,
"is_replay": flow.is_replay,
"client_conn": flow.client_conn.get_state(),
"server_conn": flow.server_conn.get_state(),
"type": flow.type,
@ -72,7 +73,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"contentHash": content_hash,
"timestamp_start": flow.request.timestamp_start,
"timestamp_end": flow.request.timestamp_end,
"is_replay": flow.request.is_replay,
"is_replay": flow.is_replay == "request", # TODO: remove, use flow.is_replay instead.
"pretty_host": flow.request.pretty_host,
}
if flow.response:
@ -91,7 +92,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"contentHash": content_hash,
"timestamp_start": flow.response.timestamp_start,
"timestamp_end": flow.response.timestamp_end,
"is_replay": flow.response.is_replay,
"is_replay": flow.is_replay == "response", # TODO: remove, use flow.is_replay instead.
}
if flow.response.data.trailers:
f["response"]["trailers"] = tuple(flow.response.data.trailers.items(True))

View File

@ -1,27 +1,47 @@
import codecs
import io
import re
from typing import Iterable, Optional, Union, cast
from typing import Iterable, Union, overload
def always_bytes(str_or_bytes: Union[str, bytes, None], *encode_args) -> Optional[bytes]:
if isinstance(str_or_bytes, bytes) or str_or_bytes is None:
return cast(Optional[bytes], str_or_bytes)
# https://mypy.readthedocs.io/en/stable/more_types.html#function-overloading
@overload
def always_bytes(str_or_bytes: None, *encode_args) -> None:
...
@overload
def always_bytes(str_or_bytes: Union[str, bytes], *encode_args) -> bytes:
...
def always_bytes(str_or_bytes: Union[None, str, bytes], *encode_args) -> Union[None, bytes]:
if str_or_bytes is None or isinstance(str_or_bytes, bytes):
return str_or_bytes
elif isinstance(str_or_bytes, str):
return str_or_bytes.encode(*encode_args)
else:
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
def always_str(str_or_bytes: Union[str, bytes, None], *decode_args) -> Optional[str]:
@overload
def always_str(str_or_bytes: None, *encode_args) -> None:
...
@overload
def always_str(str_or_bytes: Union[str, bytes], *encode_args) -> str:
...
def always_str(str_or_bytes: Union[None, str, bytes], *decode_args) -> Union[None, str]:
"""
Returns,
str_or_bytes unmodified, if
"""
if str_or_bytes is None:
return None
if isinstance(str_or_bytes, str):
return cast(str, str_or_bytes)
if str_or_bytes is None or isinstance(str_or_bytes, str):
return str_or_bytes
elif isinstance(str_or_bytes, bytes):
return str_or_bytes.decode(*decode_args)
else:

View File

@ -71,6 +71,8 @@ def check_option_type(name: str, value: typing.Any, typeinfo: Type) -> None:
elif typename.startswith("typing.Any"):
return
elif not isinstance(value, typeinfo):
if typeinfo is float and isinstance(value, int):
return
raise e

View File

@ -8,7 +8,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change in the file format.
FLOW_FORMAT_VERSION = 8
FLOW_FORMAT_VERSION = 9
def get_dev_version() -> str:

View File

@ -190,11 +190,14 @@ class Response(_HTTP2Message):
body = body.string()
resp = http.Response(
b'HTTP/2.0',
int(self.status_code.string()),
b'',
headers,
body,
http_version=b'HTTP/2.0',
status_code=int(self.status_code.string()),
reason=b'',
headers=headers,
content=body,
trailers=None,
timestamp_start=0,
timestamp_end=0
)
resp.stream_id = self.stream_id
@ -273,15 +276,18 @@ class Request(_HTTP2Message):
body = body.string()
req = http.Request(
b'',
"",
0,
self.method.string(),
b'http',
b'',
b'',
path,
(2, 0),
b"HTTP/2.0",
headers,
body,
None,
0,
0,
)
req.stream_id = self.stream_id

View File

@ -237,15 +237,18 @@ class Pathoc(tcp.TCPClient):
def http_connect(self, connect_to):
req = net_http.Request(
first_line_format='authority',
method='CONNECT',
scheme=None,
host=connect_to[0].encode("idna"),
host=connect_to[0],
port=connect_to[1],
path=None,
http_version='HTTP/1.1',
headers=[(b"Host", connect_to[0].encode("idna"))],
method=b'CONNECT',
scheme=b"",
authority=f"{connect_to[0]}:{connect_to[1]}".encode(),
path=b"",
http_version=b'HTTP/1.1',
headers=((b"Host", connect_to[0].encode("idna")),),
content=b'',
trailers=None,
timestamp_start=0,
timestamp_end=0,
)
self.wfile.write(net_http.http1.assemble_request(req))
self.wfile.flush()
@ -437,14 +440,18 @@ class Pathoc(tcp.TCPClient):
# build a dummy request to read the response
# ideally this would be returned directly from language.serve
dummy_req = net_http.Request(
first_line_format="relative",
host="localhost",
port=80,
method=req["method"],
scheme=b"http",
host=b"localhost",
port=80,
authority=b"",
path=b"/",
http_version=b"HTTP/1.1",
headers=(),
content=b'',
trailers=None,
timestamp_start=time.time(),
timestamp_end=None,
)
resp = self.protocol.read_response(self.rfile, dummy_req)

View File

@ -2,14 +2,13 @@ import itertools
import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from hpack.hpack import Decoder, Encoder
from mitmproxy.net.http import http2
import mitmproxy.net.http.headers
import mitmproxy.net.http.response
import mitmproxy.net.http.request
import mitmproxy.net.http.response
from mitmproxy.coretypes import bidi
from mitmproxy.net.http import http2, url
from .. import language
@ -98,19 +97,26 @@ class HTTP2StateProtocol:
timestamp_end = time.time()
first_line_format, method, scheme, host, port, path = http2.parse_headers(headers)
# pseudo header must be present, see https://http2.github.io/http2-spec/#rfc.section.8.1.2.3
authority = headers.pop(':authority', "")
method = headers.pop(':method', "")
scheme = headers.pop(':scheme', "")
path = headers.pop(':path', "")
request = mitmproxy.net.http.request.Request(
first_line_format,
method,
scheme,
host,
port,
path,
b"HTTP/2.0",
headers,
body,
None,
host, port = url.parse_authority(authority, check=False)
port = port or url.default_port(scheme) or 0
request = mitmproxy.net.http.Request(
host=host,
port=port,
method=method.encode(),
scheme=scheme.encode(),
authority=authority.encode(),
path=path.encode(),
http_version=b"HTTP/2.0",
headers=headers,
content=body,
trailers=None,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
@ -150,11 +156,12 @@ class HTTP2StateProtocol:
timestamp_end = None
response = mitmproxy.net.http.response.Response(
b"HTTP/2.0",
int(headers.get(':status', 502)),
b'',
headers,
body,
http_version=b"HTTP/2.0",
status_code=int(headers.get(':status', 502)),
reason=b'',
headers=headers,
content=body,
trailers=None,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)

View File

@ -20,6 +20,7 @@ exclude_lines =
raise NotImplementedError()
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
@overload
[mypy]
ignore_missing_imports = True
@ -55,13 +56,16 @@ exclude =
[tool:individual_coverage]
exclude =
mitmproxy/addons/onboardingapp/app.py
mitmproxy/addons/session.py
mitmproxy/addons/termlog.py
mitmproxy/contentviews/base.py
mitmproxy/controller.py
mitmproxy/ctx.py
mitmproxy/exceptions.py
mitmproxy/flow.py
mitmproxy/io/db.py
mitmproxy/io/io.py
mitmproxy/io/protobuf.py
mitmproxy/io/tnetstring.py
mitmproxy/log.py
mitmproxy/master.py

View File

@ -34,7 +34,6 @@ setup(
"Operating System :: POSIX",
"Operating System :: Microsoft :: Windows",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: Implementation :: CPython",

View File

@ -1,10 +1,10 @@
import io
from mitmproxy import http
from mitmproxy.addons import disable_h2c
from mitmproxy.net.http import http1
from mitmproxy.exceptions import Kill
from mitmproxy.test import tflow
from mitmproxy.net.http import http1
from mitmproxy.test import taddons
from mitmproxy.test import tflow
class TestDisableH2CleartextUpgrade:
@ -30,7 +30,7 @@ class TestDisableH2CleartextUpgrade:
b = io.BytesIO(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
f = tflow.tflow()
f.request = http.HTTPRequest.wrap(http1.read_request(b))
f.request = http1.read_request(b)
f.intercept()
a.request(f)

View File

@ -1,15 +1,14 @@
import io
import shutil
import pytest
from unittest import mock
from mitmproxy.test import tflow
from mitmproxy.test import taddons
from mitmproxy.test import tutils
import pytest
from mitmproxy.addons import dumper
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy.addons import dumper
from mitmproxy.test import taddons
from mitmproxy.test import tflow
from mitmproxy.test import tutils
def test_configure():
@ -83,7 +82,7 @@ def test_simple():
flow.client_conn = mock.MagicMock()
flow.client_conn.address[0] = "foo"
flow.response = tutils.tresp(content=None)
flow.response.is_replay = True
flow.is_replay = "response"
flow.response.status_code = 300
d.response(flow)
assert sio.getvalue()
@ -104,8 +103,7 @@ def test_simple():
ctx.configure(d, flow_detail=4)
flow = tflow.tflow()
flow.request.content = None
flow.response = http.HTTPResponse.wrap(tutils.tresp())
flow.response.content = None
flow.response = tutils.tresp(content=None)
d.response(flow)
assert "content missing" in sio.getvalue()
sio.truncate(0)
@ -135,13 +133,13 @@ def test_echo_request_line():
with taddons.context(d) as ctx:
ctx.configure(d, flow_detail=3, showhost=True)
f = tflow.tflow(client_conn=None, server_conn=True, resp=True)
f.request.is_replay = True
f.is_replay = "request"
d._echo_request_line(f)
assert "[replay]" in sio.getvalue()
sio.truncate(0)
f = tflow.tflow(client_conn=None, server_conn=True, resp=True)
f.request.is_replay = False
f.is_replay = None
d._echo_request_line(f)
assert "[replay]" not in sio.getvalue()
sio.truncate(0)

View File

@ -331,7 +331,7 @@ def test_server_playback_full():
tf = tflow.tflow()
assert not tf.response
s.request(tf)
assert tf.response == f.response
assert tf.response.data == f.response.data
tf = tflow.tflow()
tf.request.content = b"gibble"

View File

@ -160,6 +160,7 @@ class TestSession:
assert len(s._view) == 4
@pytest.mark.asyncio
@pytest.mark.skip
async def test_storage_flush_with_specials(self):
s = self.start_session(fp=0.5)
f = self.tft()
@ -187,6 +188,7 @@ class TestSession:
assert s._flush_period == s._FP_DEFAULT
@pytest.mark.asyncio
@pytest.mark.skip
async def test_storage_bodies(self):
# Need to test for configure
# Need to test for set_order
@ -202,8 +204,8 @@ class TestSession:
).fetchall()[0]
assert content == (1, b"A" * 1001)
assert s.db_store.body_ledger == {f.id}
f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001))
f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001))
f.response = tutils.tresp(content=b"A" * 1001)
f2.response = tutils.tresp(content=b"A" * 1001)
# Content length is wrong for some reason -- quick fix
f.response.headers['content-length'] = b"1001"
f2.response.headers['content-length'] = b"1001"
@ -222,6 +224,7 @@ class TestSession:
assert all([lf.__dict__ == rf.__dict__ for lf, rf in list(zip(s.load_view(), [f, f2]))])
@pytest.mark.asyncio
@pytest.mark.skip
async def test_storage_order(self):
s = self.start_session(fp=0.5)
s.request(self.tft(method="GET", start=4))

View File

@ -1,7 +1,10 @@
import pytest
from mitmproxy.io import db
from mitmproxy.test import tflow
@pytest.mark.skip
class TestDB:
def test_create(self, tdata):

View File

@ -1,12 +1,12 @@
import pytest
from mitmproxy import certs
from mitmproxy import http
from mitmproxy import exceptions
from mitmproxy.test import tflow, tutils
from mitmproxy.io import protobuf
from mitmproxy.test import tflow, tutils
@pytest.mark.skip
class TestProtobuf:
def test_roundtrip_client(self):
@ -66,25 +66,25 @@ class TestProtobuf:
assert s.via.__dict__ == ls.via.__dict__
def test_roundtrip_http_request(self):
req = http.HTTPRequest.wrap(tutils.treq())
req = tutils.treq()
preq = protobuf._dump_http_request(req)
lreq = protobuf._load_http_request(preq)
assert req.__dict__ == lreq.__dict__
def test_roundtrip_http_request_empty_content(self):
req = http.HTTPRequest.wrap(tutils.treq(content=b""))
req = tutils.treq(content=b"")
preq = protobuf._dump_http_request(req)
lreq = protobuf._load_http_request(preq)
assert req.__dict__ == lreq.__dict__
def test_roundtrip_http_response(self):
res = http.HTTPResponse.wrap(tutils.tresp())
res = tutils.tresp()
pres = protobuf._dump_http_response(res)
lres = protobuf._load_http_response(pres)
assert res.__dict__ == lres.__dict__
def test_roundtrip_http_response_empty_content(self):
res = http.HTTPResponse.wrap(tutils.tresp(content=b""))
res = tutils.tresp(content=b"")
pres = protobuf._dump_http_response(res)
lres = protobuf._load_http_response(pres)
assert res.__dict__ == lres.__dict__

View File

@ -65,15 +65,12 @@ def test_assemble_body():
def test_assemble_request_line():
assert _assemble_request_line(treq().data) == b"GET /path HTTP/1.1"
authority_request = treq(method=b"CONNECT", first_line_format="authority").data
authority_request = treq(method=b"CONNECT", authority=b"address:22").data
assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1"
absolute_request = treq(first_line_format="absolute").data
absolute_request = treq(scheme=b"http", authority=b"address:22").data
assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1"
with pytest.raises(RuntimeError):
_assemble_request_line(treq(first_line_format="invalid_form").data)
def test_assemble_request_headers():
# https://github.com/mitmproxy/mitmproxy/issues/186

View File

@ -7,7 +7,7 @@ from mitmproxy.net.http import Headers
from mitmproxy.net.http.http1.read import (
read_request, read_response, read_request_head,
read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line,
_read_request_line, _parse_authority_form, _read_response_line, _check_http_version,
_read_request_line, _read_response_line, _check_http_version,
_read_headers, _read_chunked, get_header_tokens
)
from mitmproxy.test.tutils import treq, tresp
@ -242,35 +242,26 @@ def test_read_request_line():
return _read_request_line(BytesIO(b))
assert (t(b"GET / HTTP/1.1") ==
("relative", b"GET", None, None, None, b"/", b"HTTP/1.1"))
("", 0, b"GET", b"", b"", b"/", b"HTTP/1.1"))
assert (t(b"OPTIONS * HTTP/1.1") ==
("relative", b"OPTIONS", None, None, None, b"*", b"HTTP/1.1"))
("", 0, b"OPTIONS", b"", b"", b"*", b"HTTP/1.1"))
assert (t(b"CONNECT foo:42 HTTP/1.1") ==
("authority", b"CONNECT", None, b"foo", 42, None, b"HTTP/1.1"))
("foo", 42, b"CONNECT", b"", b"foo:42", b"", b"HTTP/1.1"))
assert (t(b"GET http://foo:42/bar HTTP/1.1") ==
("absolute", b"GET", b"http", b"foo", 42, b"/bar", b"HTTP/1.1"))
("foo", 42, b"GET", b"http", b"foo:42", b"/bar", b"HTTP/1.1"))
with pytest.raises(exceptions.HttpSyntaxException):
t(b"GET / WTF/1.1")
with pytest.raises(exceptions.HttpSyntaxException):
t(b"CONNECT example.com HTTP/1.1") # port missing
with pytest.raises(exceptions.HttpSyntaxException):
t(b"GET ws://example.com/ HTTP/1.1") # port missing
with pytest.raises(exceptions.HttpSyntaxException):
t(b"this is not http")
with pytest.raises(exceptions.HttpReadDisconnect):
t(b"")
def test_parse_authority_form():
assert _parse_authority_form(b"foo:42") == (b"foo", 42)
assert _parse_authority_form(b"[2001:db8:42::]:443") == (b"2001:db8:42::", 443)
with pytest.raises(exceptions.HttpSyntaxException):
_parse_authority_form(b"foo")
with pytest.raises(exceptions.HttpSyntaxException):
_parse_authority_form(b"foo:bar")
with pytest.raises(exceptions.HttpSyntaxException):
_parse_authority_form(b"foo:99999999")
with pytest.raises(exceptions.HttpSyntaxException):
_parse_authority_form(b"f\x00oo:80")
def test_read_response_line():
def t(b):
return _read_response_line(BytesIO(b))

View File

@ -64,17 +64,17 @@ class TestMessage:
def test_eq_ne(self):
resp = tutils.tresp(timestamp_start=42, timestamp_end=42)
same = tutils.tresp(timestamp_start=42, timestamp_end=42)
assert resp == same
assert resp.data == same.data
other = tutils.tresp(timestamp_start=0, timestamp_end=0)
assert resp != other
assert resp.data != other.data
assert resp != 0
def test_serializable(self):
resp = tutils.tresp()
resp2 = http.Response.from_state(resp.get_state())
assert resp == resp2
assert resp.data == resp2.data
def test_content_length_update(self):
resp = tutils.tresp()

View File

@ -32,12 +32,29 @@ class TestRequestCore:
"""
Tests for addons and the attributes that are directly proxied from the data structure
"""
def test_repr(self):
request = treq()
assert repr(request) == "Request(GET address:22/path)"
request.host = None
assert repr(request) == "Request(GET /path)"
def test_init_conv(self):
assert Request(
b"example.com",
80,
"GET",
"http",
"example.com",
"/",
"HTTP/1.1",
(),
None,
(),
0,
0,
) # type: ignore
def test_make(self):
r = Request.make("GET", "https://example.com/")
assert r.method == "GET"
@ -61,56 +78,55 @@ class TestRequestCore:
r = Request.make("GET", "https://example.com/", headers=({"foo": "baz"}))
assert r.headers["foo"] == "baz"
r = Request.make("GET", "https://example.com/", headers=Headers(foo="qux"))
assert r.headers["foo"] == "qux"
with pytest.raises(TypeError):
Request.make("GET", "https://example.com/", headers=42)
def test_first_line_format(self):
_test_passthrough_attr(treq(), "first_line_format")
assert treq(method=b"CONNECT").first_line_format == "authority"
assert treq(authority=b"example.com").first_line_format == "absolute"
assert treq(authority=b"").first_line_format == "relative"
def test_method(self):
_test_decoded_attr(treq(), "method")
def test_scheme(self):
_test_decoded_attr(treq(), "scheme")
assert treq(scheme=None).scheme is None
def test_port(self):
_test_passthrough_attr(treq(), "port")
def test_path(self):
req = treq()
_test_decoded_attr(req, "path")
# path can also be None.
req.path = None
assert req.path is None
assert req.data.path is None
_test_decoded_attr(treq(), "path")
def test_host(self):
def test_authority(self):
request = treq()
assert request.host == request.data.host.decode("idna")
assert request.authority == request.data.authority.decode("idna")
# Test IDNA encoding
# Set str, get raw bytes
request.host = "ídna.example"
assert request.data.host == b"xn--dna-qma.example"
request.authority = "ídna.example"
assert request.data.authority == b"xn--dna-qma.example"
# Set raw bytes, get decoded
request.data.host = b"xn--idn-gla.example"
assert request.host == "idná.example"
request.data.authority = b"xn--idn-gla.example"
assert request.authority == "idná.example"
# Set bytes, get raw bytes
request.host = b"xn--dn-qia9b.example"
assert request.data.host == b"xn--dn-qia9b.example"
request.authority = b"xn--dn-qia9b.example"
assert request.data.authority == b"xn--dn-qia9b.example"
# IDNA encoding is not bijective
request.host = "fußball"
assert request.host == "fussball"
request.authority = "fußball"
assert request.authority == "fussball"
# Don't fail on garbage
request.data.host = b"foo\xFF\x00bar"
assert request.host.startswith("foo")
assert request.host.endswith("bar")
request.data.authority = b"foo\xFF\x00bar"
assert request.authority.startswith("foo")
assert request.authority.endswith("bar")
# foo.bar = foo.bar should not cause any side effects.
d = request.host
request.host = d
assert request.data.host == b"foo\xFF\x00bar"
d = request.authority
request.authority = d
assert request.data.authority == b"foo\xFF\x00bar"
def test_host_update_also_updates_header(self):
request = treq()
@ -119,59 +135,61 @@ class TestRequestCore:
assert "host" not in request.headers
request.headers["Host"] = "foo"
request.authority = "foo"
request.host = "example.org"
assert request.headers["Host"] == "example.org"
assert request.authority == "example.org:22"
def test_get_host_header(self):
no_hdr = treq()
assert no_hdr.host_header is None
h1 = treq(headers=(
(b"host", b"example.com"),
))
assert h1.host_header == "example.com"
h1 = treq(
headers=((b"host", b"header.example.com"),),
authority=b"authority.example.com"
)
assert h1.host_header == "header.example.com"
h2 = treq(headers=(
(b":authority", b"example.org"),
))
assert h2.host_header == "example.org"
h2 = h1.copy()
h2.http_version = "HTTP/2.0"
assert h2.host_header == "authority.example.com"
both_hdrs = treq(headers=(
(b"host", b"example.org"),
(b":authority", b"example.com"),
))
assert both_hdrs.host_header == "example.com"
h2_host_only = h2.copy()
h2_host_only.authority = ""
assert h2_host_only.host_header == "header.example.com"
def test_modify_host_header(self):
h1 = treq()
assert "host" not in h1.headers
assert ":authority" not in h1.headers
h1.host_header = "example.com"
assert "host" in h1.headers
assert ":authority" not in h1.headers
assert h1.headers["Host"] == "example.com"
assert not h1.authority
h1.host_header = None
assert "host" not in h1.headers
assert not h1.authority
h2 = treq(http_version=b"HTTP/2.0")
h2.host_header = "example.org"
assert "host" not in h2.headers
assert ":authority" in h2.headers
del h2.host_header
assert ":authority" not in h2.headers
assert h2.authority == "example.org"
both_hdrs = treq(headers=(
(b":authority", b"example.com"),
(b"host", b"example.org"),
))
both_hdrs.host_header = "foo.example.com"
assert both_hdrs.headers["Host"] == "foo.example.com"
assert both_hdrs.headers[":authority"] == "foo.example.com"
h2.headers["Host"] = "example.org"
h2.host_header = "foo.example.com"
assert h2.headers["Host"] == "foo.example.com"
assert h2.authority == "foo.example.com"
h2.host_header = None
assert "host" not in h2.headers
assert not h2.authority
class TestRequestUtils:
"""
Tests for additional convenience methods.
"""
def test_url(self):
request = treq()
assert request.url == "http://address:22/path"
@ -190,7 +208,7 @@ class TestRequestUtils:
assert request.url == "http://address:22"
def test_url_authority(self):
request = treq(first_line_format="authority")
request = treq(method=b"CONNECT")
assert request.url == "address:22"
def test_pretty_host(self):
@ -201,17 +219,9 @@ class TestRequestUtils:
# Same port as self.port (22)
request.headers["host"] = "other:22"
assert request.pretty_host == "other"
# Different ports
request.headers["host"] = "other"
assert request.pretty_host == "address"
assert request.host == "address"
# Empty host
request.host = None
assert request.pretty_host is None
assert request.host is None
# Invalid IDNA
request.headers["host"] = ".disqus.com:22"
request.headers["host"] = ".disqus.com"
assert request.pretty_host == ".disqus.com"
def test_pretty_url(self):
@ -219,19 +229,19 @@ class TestRequestUtils:
# Without host header
assert request.url == "http://address:22/path"
assert request.pretty_url == "http://address:22/path"
# Same port as self.port (22)
request.headers["host"] = "other:22"
assert request.pretty_url == "http://other:22/path"
# Different ports
request.headers["host"] = "other"
assert request.pretty_url == "http://address:22/path"
request = treq(method=b"CONNECT", authority=b"example:44")
assert request.pretty_url == "example:44"
def test_pretty_url_options(self):
request = treq(method=b"OPTIONS", path=b"*")
assert request.pretty_url == "http://address:22"
def test_pretty_url_authority(self):
request = treq(first_line_format="authority")
request = treq(method=b"CONNECT", authority="address:22")
assert request.pretty_url == "address:22"
def test_get_query(self):

View File

@ -33,9 +33,9 @@ class TestResponseCore:
"""
def test_repr(self):
response = tresp()
assert repr(response) == "Response(200 OK, unknown content type, 7b)"
assert repr(response) == "Response(200, unknown content type, 7b)"
response.content = None
assert repr(response) == "Response(200 OK, no content)"
assert repr(response) == "Response(200, no content)"
def test_make(self):
r = Response.make()
@ -58,6 +58,9 @@ class TestResponseCore:
r = Response.make(headers=({"foo": "baz"}))
assert r.headers["foo"] == "baz"
r = Response.make(headers=Headers(foo="qux"))
assert r.headers["foo"] == "qux"
with pytest.raises(TypeError):
Response.make(headers=42)
@ -74,18 +77,9 @@ class TestResponseCore:
resp.reason = b"DEF"
assert resp.data.reason == b"DEF"
resp.reason = None
assert resp.data.reason is None
resp.data.reason = b'cr\xe9e'
assert resp.reason == "crée"
# HTTP2 responses do not contain a reason phrase and self.data.reason will be None.
# This should render to an empty reason phrase so that functions
# expecting a string work properly.
resp.data.reason = None
assert resp.reason == ""
class TestResponseUtils:
"""

View File

@ -1,7 +1,10 @@
from typing import AnyStr
import pytest
import sys
from mitmproxy.net.http import url
from mitmproxy.net.http.url import parse_authority
def test_parse():
@ -50,7 +53,6 @@ def test_parse():
def test_ascii_check():
test_url = "https://xyz.tax-edu.net?flag=selectCourse&lc_id=42825&lc_name=茅莽莽猫氓猫氓".encode()
scheme, host, port, full_path = url.parse(test_url)
assert scheme == b'https'
@ -115,10 +117,10 @@ def test_empty_key_trailing_equal_sign():
post_data_empty_key_middle = [('one', 'two'), ('emptykey', ''), ('three', 'four')]
post_data_empty_key_end = [('one', 'two'), ('three', 'four'), ('emptykey', '')]
assert url.encode(post_data_empty_key_middle, similar_to = reference_with_equal) == "one=two&emptykey=&three=four"
assert url.encode(post_data_empty_key_end, similar_to = reference_with_equal) == "one=two&three=four&emptykey="
assert url.encode(post_data_empty_key_middle, similar_to = reference_without_equal) == "one=two&emptykey&three=four"
assert url.encode(post_data_empty_key_end, similar_to = reference_without_equal) == "one=two&three=four&emptykey"
assert url.encode(post_data_empty_key_middle, similar_to=reference_with_equal) == "one=two&emptykey=&three=four"
assert url.encode(post_data_empty_key_end, similar_to=reference_with_equal) == "one=two&three=four&emptykey="
assert url.encode(post_data_empty_key_middle, similar_to=reference_without_equal) == "one=two&emptykey&three=four"
assert url.encode(post_data_empty_key_end, similar_to=reference_without_equal) == "one=two&three=four&emptykey"
def test_encode():
@ -147,3 +149,33 @@ def test_unquote():
def test_hostport():
assert url.hostport(b"https", b"foo.com", 8080) == b"foo.com:8080"
def test_default_port():
assert url.default_port("http") == 80
assert url.default_port(b"https") == 443
assert url.default_port(b"qux") is None
@pytest.mark.parametrize(
"authority,valid,out", [
["foo:42", True, ("foo", 42)],
[b"foo:42", True, ("foo", 42)],
["127.0.0.1:443", True, ("127.0.0.1", 443)],
["[2001:db8:42::]:443", True, ("2001:db8:42::", 443)],
[b"xn--aaa-pla.example:80", True, ("äaaa.example", 80)],
["foo", True, ("foo", None)],
["foo..bar", False, ("foo..bar", None)],
["foo:bar", False, ("foo:bar", None)],
["foo:999999999", False, ("foo:999999999", None)],
[b"\xff", False, ('\udcff', None)]
]
)
def test_parse_authority(authority: AnyStr, valid: bool, out):
assert parse_authority(authority, False) == out
if valid:
assert parse_authority(authority, True) == out
else:
with pytest.raises(ValueError):
parse_authority(authority, True)

View File

@ -69,4 +69,8 @@ def test_is_valid_host():
assert check.is_valid_host(b'api-.a.example.com')
assert check.is_valid_host(b'api-._a.example.com')
assert check.is_valid_host(b'api-.a_.example.com')
assert check.is_valid_host(b'api-.ab.example.com')
assert check.is_valid_host(b'api-.ab.example.com')
# Test str
assert check.is_valid_host('example.tld')
assert not check.is_valid_host("foo..bar") # cannot be idna-encoded.

View File

@ -59,6 +59,7 @@ class TestExpectHeader(tservers.HTTPProxyTest):
client.wfile.flush()
assert client.rfile.readline() == b"HTTP/1.1 100 Continue\r\n"
assert client.rfile.readline() == b"content-length: 0\r\n"
assert client.rfile.readline() == b"\r\n"
client.wfile.write(b"0123456789abcdef\r\n")

View File

@ -10,6 +10,7 @@ import h2
from mitmproxy import options
import mitmproxy.net
import mitmproxy.http
from ...net import tservers as net_tservers
from mitmproxy import exceptions
from mitmproxy.net.http import http1, http2
@ -124,17 +125,9 @@ class _Http2TestBase:
self.client.connect()
# send CONNECT request
self.client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request(
'authority',
b'CONNECT',
b'',
b'localhost',
self.server.server.address[1],
b'/',
b'HTTP/1.1',
[(b'host', b'localhost:%d' % self.server.server.address[1])],
b'',
)))
self.client.wfile.write(http1.assemble_request(
mitmproxy.http.make_connect_request(("localhost", self.server.server.address[1]))
))
self.client.wfile.flush()
# read CONNECT response

View File

@ -6,7 +6,7 @@ import traceback
from mitmproxy import options
from mitmproxy import exceptions
from mitmproxy.http import HTTPFlow
from mitmproxy.http import HTTPFlow, make_connect_request
from mitmproxy.websocket import WebSocketFlow
from mitmproxy.net import tcp
@ -27,9 +27,9 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
assert websockets.check_handshake(request.headers)
response = http.Response(
"HTTP/1.1",
101,
reason=http.status_codes.RESPONSES.get(101),
http_version=b"HTTP/1.1",
status_code=101,
reason=http.status_codes.RESPONSES.get(101).encode(),
headers=http.Headers(
connection='upgrade',
upgrade='websocket',
@ -37,6 +37,9 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else ''
),
content=b'',
trailers=None,
timestamp_start=0,
timestamp_end=0,
)
self.wfile.write(http.http1.assemble_response(response))
self.wfile.flush()
@ -86,15 +89,7 @@ class _WebSocketTestBase:
self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
self.client.connect()
request = http.Request(
"authority",
"CONNECT",
"",
"127.0.0.1",
self.server.server.address[1],
"",
"HTTP/1.1",
content=b'')
request = make_connect_request(("127.0.0.1", self.server.server.address[1]))
self.client.wfile.write(http.http1.assemble_request(request))
self.client.wfile.flush()
@ -105,13 +100,13 @@ class _WebSocketTestBase:
assert self.client.tls_established
request = http.Request(
"relative",
"GET",
"http",
"127.0.0.1",
self.server.server.address[1],
"/ws",
"HTTP/1.1",
host="127.0.0.1",
port=self.server.server.address[1],
method=b"GET",
scheme=b"http",
authority=b"",
path=b"/ws",
http_version=b"HTTP/1.1",
headers=http.Headers(
connection="upgrade",
upgrade="websocket",
@ -119,7 +114,11 @@ class _WebSocketTestBase:
sec_websocket_key="1234",
sec_websocket_extensions="permessage-deflate" if extension else ""
),
content=b'')
content=b'',
trailers=None,
timestamp_start=0,
timestamp_end=0,
)
self.client.wfile.write(http.http1.assemble_request(request))
self.client.wfile.flush()

View File

@ -206,7 +206,7 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin):
p = self.pathoc()
with p.connect():
ret = p.request("get:'https://localhost:%s/'" % self.server.port)
assert ret.status_code == 400
assert ret.status_code == 502
def test_connection_close(self):
# Add a body, so we have a content-length header, which combined with
@ -806,7 +806,7 @@ class TestStreamRequest(tservers.HTTPProxyTest):
class AFakeResponse:
def request(self, f):
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
f.response = mitmproxy.test.tutils.tresp()
class TestFakeResponse(tservers.HTTPProxyTest):
@ -873,7 +873,7 @@ class TestTransparentResolveError(tservers.TransparentProxyTest):
class AIncomplete:
def request(self, f):
resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
resp = mitmproxy.test.tutils.tresp()
resp.content = None
f.response = resp

View File

@ -1,14 +1,14 @@
import io
import pytest
from mitmproxy.test import tflow, taddons
import mitmproxy.io
from mitmproxy import flow
from mitmproxy import flowfilter
from mitmproxy import options
from mitmproxy.io import tnetstring
from mitmproxy.exceptions import FlowReadException
from mitmproxy import flow
from mitmproxy import http
from mitmproxy.io import tnetstring
from mitmproxy.test import taddons, tflow
from . import tservers
@ -29,7 +29,7 @@ class TestSerialize:
f2 = l[0]
assert f2.get_state() == f.get_state()
assert f2.request == f.request
assert f2.request.data == f.request.data
assert f2.marked
def test_filter(self):
@ -128,11 +128,11 @@ class TestFlowMaster:
with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(req=None)
await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
f.request = mitmproxy.test.tutils.treq()
await ctx.master.addons.handle_lifecycle("request", f)
assert len(s.flows) == 1
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
f.response = mitmproxy.test.tutils.tresp()
await ctx.master.addons.handle_lifecycle("response", f)
assert len(s.flows) == 1

View File

@ -24,7 +24,7 @@ class TestHTTPRequest:
assert hash(r)
def test_get_url(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
r = mitmproxy.test.tutils.treq()
assert r.url == "http://address:22/path"
@ -45,7 +45,7 @@ class TestHTTPRequest:
assert r.pretty_url == "https://foo.com:22/path"
def test_constrain_encoding(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
r = mitmproxy.test.tutils.treq()
r.headers["accept-encoding"] = "gzip, oink"
r.constrain_encoding()
assert "oink" not in r.headers["accept-encoding"]
@ -55,7 +55,7 @@ class TestHTTPRequest:
assert "oink" not in r.headers["accept-encoding"]
def test_get_content_type(self):
resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
resp = mitmproxy.test.tutils.tresp()
resp.headers = Headers(content_type="text/plain")
assert resp.headers["content-type"] == "text/plain"
@ -69,7 +69,7 @@ class TestHTTPResponse:
assert resp2.get_state() == resp.get_state()
def test_get_content_type(self):
resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
resp = mitmproxy.test.tutils.tresp()
resp.headers = Headers(content_type="text/plain")
assert resp.headers["content-type"] == "text/plain"
@ -118,7 +118,7 @@ class TestHTTPFlow:
def test_backup(self):
f = tflow.tflow()
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
f.response = mitmproxy.test.tutils.tresp()
f.request.content = b"foo"
assert not f.modified()
f.backup()
@ -218,5 +218,6 @@ def test_make_connect_response():
def test_expect_continue_response():
assert http.expect_continue_response.http_version == 'HTTP/1.1'
assert http.expect_continue_response.status_code == 100
resp = http.make_expect_continue_response()
assert resp.http_version == 'HTTP/1.1'
assert resp.status_code == 100

View File

@ -17,6 +17,7 @@ class T(TBase):
def test_check_option_type():
typecheck.check_option_type("foo", 42, int)
typecheck.check_option_type("foo", 42, float)
with pytest.raises(TypeError):
typecheck.check_option_type("foo", 42, str)
with pytest.raises(TypeError):

View File

@ -1,15 +1,13 @@
from unittest import mock
import codecs
import pytest
import hyperframe
import pytest
from mitmproxy.net import tcp, http
from mitmproxy.net.http import http2
from mitmproxy import exceptions
from ...mitmproxy.net import tservers as net_tservers
from mitmproxy.net import http, tcp
from mitmproxy.net.http import http2
from pathod.protocols.http2 import HTTP2StateProtocol, TCPHandler
from ...mitmproxy.net import tservers as net_tservers
class TestTCPHandlerWrapper:
@ -100,23 +98,23 @@ class TestPerformServerConnectionPreface(net_tservers.ServerTestBase):
def handle(self):
# send magic
self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec'))
self.wfile.write(bytes.fromhex("505249202a20485454502f322e300d0a0d0a534d0d0a0d0a"))
self.wfile.flush()
# send empty settings frame
self.wfile.write(codecs.decode('000000040000000000', 'hex_codec'))
self.wfile.write(bytes.fromhex("000000040000000000"))
self.wfile.flush()
# check empty settings frame
raw = http2.read_raw_frame(self.rfile)
assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec')
assert raw == bytes.fromhex("00000c040000000000000200000000000300000001")
# check settings acknowledgement
raw = http2.read_raw_frame(self.rfile)
assert raw == codecs.decode('000000040100000000', 'hex_codec')
assert raw == bytes.fromhex("000000040100000000")
# send settings acknowledgement
self.wfile.write(codecs.decode('000000040100000000', 'hex_codec'))
self.wfile.write(bytes.fromhex("000000040100000000"))
self.wfile.flush()
def test_perform_server_connection_preface(self):
@ -141,18 +139,18 @@ class TestPerformClientConnectionPreface(net_tservers.ServerTestBase):
# check empty settings frame
assert self.rfile.read(9) ==\
codecs.decode('000000040000000000', 'hex_codec')
bytes.fromhex("000000040000000000")
# send empty settings frame
self.wfile.write(codecs.decode('000000040000000000', 'hex_codec'))
self.wfile.write(bytes.fromhex("000000040000000000"))
self.wfile.flush()
# check settings acknowledgement
assert self.rfile.read(9) == \
codecs.decode('000000040100000000', 'hex_codec')
bytes.fromhex("000000040100000000")
# send settings acknowledgement
self.wfile.write(codecs.decode('000000040100000000', 'hex_codec'))
self.wfile.write(bytes.fromhex("000000040100000000"))
self.wfile.flush()
def test_perform_client_connection_preface(self):
@ -197,7 +195,7 @@ class TestApplySettings(net_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
# check settings acknowledgement
assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec')
assert self.rfile.read(9) == bytes.fromhex("000000040100000000")
self.wfile.write(b"OK")
self.wfile.flush()
self.rfile.safe_read(9) # just to keep the connection alive a bit longer
@ -236,15 +234,13 @@ class TestCreateHeaders:
(b':scheme', b'https'),
(b'foo', b'bar')])
bytes = HTTP2StateProtocol(self.c)._create_headers(
data = HTTP2StateProtocol(self.c)._create_headers(
headers, 1, end_stream=True)
assert b''.join(bytes) ==\
codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec')
assert b''.join(data) == bytes.fromhex("000014010500000001824488355217caf3a69a3f87408294e7838c767f")
bytes = HTTP2StateProtocol(self.c)._create_headers(
data = HTTP2StateProtocol(self.c)._create_headers(
headers, 1, end_stream=False)
assert b''.join(bytes) ==\
codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec')
assert b''.join(data) == bytes.fromhex("000014010400000001824488355217caf3a69a3f87408294e7838c767f")
def test_create_headers_multiple_frames(self):
headers = http.Headers([
@ -256,11 +252,11 @@ class TestCreateHeaders:
protocol = HTTP2StateProtocol(self.c)
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8
bytes = protocol._create_headers(headers, 1, end_stream=True)
assert len(bytes) == 3
assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec')
assert bytes[1] == codecs.decode('0000080900000000018c767f7685ee5b10', 'hex_codec')
assert bytes[2] == codecs.decode('00000209040000000163d5', 'hex_codec')
data = protocol._create_headers(headers, 1, end_stream=True)
assert len(data) == 3
assert data[0] == bytes.fromhex("000008010100000001828487408294e783")
assert data[1] == bytes.fromhex("0000080900000000018c767f7685ee5b10")
assert data[2] == bytes.fromhex("00000209040000000163d5")
class TestCreateBody:
@ -273,17 +269,17 @@ class TestCreateBody:
def test_create_body_single_frame(self):
protocol = HTTP2StateProtocol(self.c)
bytes = protocol._create_body(b'foobar', 1)
assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec')
data = protocol._create_body(b'foobar', 1)
assert b''.join(data) == bytes.fromhex("000006000100000001666f6f626172")
def test_create_body_multiple_frames(self):
protocol = HTTP2StateProtocol(self.c)
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5
bytes = protocol._create_body(b'foobarmehm42', 1)
assert len(bytes) == 3
assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec')
assert bytes[1] == codecs.decode('000005000000000001726d65686d', 'hex_codec')
assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec')
data = protocol._create_body(b'foobarmehm42', 1)
assert len(data) == 3
assert data[0] == bytes.fromhex("000005000000000001666f6f6261")
assert data[1] == bytes.fromhex("000005000000000001726d65686d")
assert data[2] == bytes.fromhex("0000020001000000013432")
class TestReadRequest(net_tservers.ServerTestBase):
@ -291,9 +287,9 @@ class TestReadRequest(net_tservers.ServerTestBase):
def handle(self):
self.wfile.write(
codecs.decode('000003010400000001828487', 'hex_codec'))
bytes.fromhex("000003010400000001828487"))
self.wfile.write(
codecs.decode('000006000100000001666f6f626172', 'hex_codec'))
bytes.fromhex("000006000100000001666f6f626172"))
self.wfile.flush()
self.rfile.safe_read(9) # just to keep the connection alive a bit longer
@ -320,7 +316,7 @@ class TestReadRequestRelative(net_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
codecs.decode('00000c0105000000014287d5af7e4d5a777f4481f9', 'hex_codec'))
bytes.fromhex("00000c0105000000014287d5af7e4d5a777f4481f9"))
self.wfile.flush()
ssl = True
@ -339,37 +335,13 @@ class TestReadRequestRelative(net_tservers.ServerTestBase):
assert req.path == "*"
class TestReadRequestAbsolute(net_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
codecs.decode('00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085', 'hex_codec'))
self.wfile.flush()
ssl = True
def test_absolute_form(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls()
protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented)
assert req.first_line_format == "absolute"
assert req.scheme == "http"
assert req.host == "address"
assert req.port == 22
class TestReadResponse(net_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
codecs.decode('00000801040000002a88628594e78c767f', 'hex_codec'))
bytes.fromhex("00000801040000002a88628594e78c767f"))
self.wfile.write(
codecs.decode('00000600010000002a666f6f626172', 'hex_codec'))
bytes.fromhex("00000600010000002a666f6f626172"))
self.wfile.flush()
self.rfile.safe_read(9) # just to keep the connection alive a bit longer
@ -396,7 +368,7 @@ class TestReadEmptyResponse(net_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
codecs.decode('00000801050000002a88628594e78c767f', 'hex_codec'))
bytes.fromhex("00000801050000002a88628594e78c767f"))
self.wfile.flush()
ssl = True
@ -422,89 +394,107 @@ class TestAssembleRequest:
c = tcp.TCPClient(("127.0.0.1", 0))
def test_request_simple(self):
bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request(
b'',
b'GET',
b'https',
b'',
b'',
b'/',
b"HTTP/2.0",
(),
None,
data = HTTP2StateProtocol(self.c).assemble_request(http.Request(
host="",
port=0,
method=b'GET',
scheme=b'https',
authority=b'',
path=b'/',
http_version=b"HTTP/2.0",
headers=(),
content=None,
trailers=None,
timestamp_start=0,
timestamp_end=0
))
assert len(bytes) == 1
assert bytes[0] == codecs.decode('00000d0105000000018284874188089d5c0b8170dc07', 'hex_codec')
assert len(data) == 1
assert data[0] == bytes.fromhex('00000d0105000000018284874188089d5c0b8170dc07')
def test_request_with_stream_id(self):
req = http.Request(
b'',
b'GET',
b'https',
b'',
b'',
b'/',
b"HTTP/2.0",
(),
None,
host="",
port=0,
method=b'GET',
scheme=b'https',
authority=b'',
path=b'/',
http_version=b"HTTP/2.0",
headers=(),
content=None,
trailers=None,
timestamp_start=0,
timestamp_end=0
)
req.stream_id = 0x42
bytes = HTTP2StateProtocol(self.c).assemble_request(req)
assert len(bytes) == 1
assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec')
data = HTTP2StateProtocol(self.c).assemble_request(req)
assert len(data) == 1
assert data[0] == bytes.fromhex('00000d0105000000428284874188089d5c0b8170dc07')
def test_request_with_body(self):
bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request(
b'',
b'GET',
b'https',
b'',
b'',
b'/',
b"HTTP/2.0",
http.Headers([(b'foo', b'bar')]),
b'foobar',
data = HTTP2StateProtocol(self.c).assemble_request(http.Request(
host="",
port=0,
method=b'GET',
scheme=b'https',
authority=b'',
path=b'/',
http_version=b"HTTP/2.0",
headers=http.Headers([(b'foo', b'bar')]),
content=b'foobar',
trailers=None,
timestamp_start=0,
timestamp_end=None,
))
assert len(bytes) == 2
assert bytes[0] ==\
codecs.decode('0000150104000000018284874188089d5c0b8170dc07408294e7838c767f', 'hex_codec')
assert bytes[1] ==\
codecs.decode('000006000100000001666f6f626172', 'hex_codec')
assert len(data) == 2
assert data[0] == bytes.fromhex("0000150104000000018284874188089d5c0b8170dc07408294e7838c767f")
assert data[1] == bytes.fromhex("000006000100000001666f6f626172")
class TestAssembleResponse:
c = tcp.TCPClient(("127.0.0.1", 0))
def test_simple(self):
bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
b"HTTP/2.0",
200,
data = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
http_version=b"HTTP/2.0",
status_code=200,
reason=b"",
headers=(),
content=b"",
trailers=None,
timestamp_start=0,
timestamp_end=0,
))
assert len(bytes) == 1
assert bytes[0] ==\
codecs.decode('00000101050000000288', 'hex_codec')
assert len(data) == 1
assert data[0] == bytes.fromhex("00000101050000000288")
def test_with_stream_id(self):
resp = http.Response(
b"HTTP/2.0",
200,
http_version=b"HTTP/2.0",
status_code=200,
reason=b"",
headers=(),
content=b"",
trailers=None,
timestamp_start=0,
timestamp_end=0,
)
resp.stream_id = 0x42
bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(resp)
assert len(bytes) == 1
assert bytes[0] ==\
codecs.decode('00000101050000004288', 'hex_codec')
data = HTTP2StateProtocol(self.c, is_server=True).assemble_response(resp)
assert len(data) == 1
assert data[0] == bytes.fromhex("00000101050000004288")
def test_with_body(self):
bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
b"HTTP/2.0",
200,
b'',
http.Headers(foo=b"bar"),
b'foobar'
data = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
http_version=b"HTTP/2.0",
status_code=200,
reason=b'',
headers=http.Headers(foo=b"bar"),
content=b'foobar',
trailers=None,
timestamp_start=0,
timestamp_end=0,
))
assert len(bytes) == 2
assert bytes[0] ==\
codecs.decode('00000901040000000288408294e7838c767f', 'hex_codec')
assert bytes[1] ==\
codecs.decode('000006000100000002666f6f626172', 'hex_codec')
assert len(data) == 2
assert data[0] == bytes.fromhex("00000901040000000288408294e7838c767f")
assert data[1] == bytes.fromhex("000006000100000002666f6f626172")

View File

@ -1,23 +1,16 @@
import io
from unittest.mock import Mock
import pytest
from mitmproxy.net import http
from mitmproxy.net.http import http1
from mitmproxy import exceptions
from pathod import pathoc, language
from pathod.protocols.http2 import HTTP2StateProtocol
from mitmproxy.net.http import http1
from mitmproxy.test import tutils
from pathod import language, pathoc
from pathod.protocols.http2 import HTTP2StateProtocol
from . import tservers
def test_response():
r = http.Response(b"HTTP/1.1", 200, b"Message", {}, None, None)
assert repr(r)
class PathocTestDaemon(tservers.DaemonTests):
def tval(self, requests, timeout=None, showssl=False, **kwargs):
s = io.StringIO()