mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] lint!
This commit is contained in:
parent
6b2e49eb13
commit
8201a90e22
@ -33,16 +33,16 @@ class MockServer(layers.http.HttpConnection):
|
||||
|
||||
def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
has_content = bool(self.flow.request.raw_content)
|
||||
content = self.flow.request.raw_content
|
||||
self.flow.request.timestamp_start = self.flow.request.timestamp_end = time.time()
|
||||
yield layers.http.ReceiveHttp(layers.http.RequestHeaders(
|
||||
1,
|
||||
self.flow.request,
|
||||
end_stream=not has_content,
|
||||
end_stream=not content,
|
||||
replay_flow=self.flow,
|
||||
))
|
||||
if has_content:
|
||||
yield layers.http.ReceiveHttp(layers.http.RequestData(1, self.flow.request.raw_content))
|
||||
if content:
|
||||
yield layers.http.ReceiveHttp(layers.http.RequestData(1, content))
|
||||
yield layers.http.ReceiveHttp(layers.http.RequestEndOfMessage(1))
|
||||
elif isinstance(event, (
|
||||
layers.http.ResponseHeaders,
|
||||
@ -56,6 +56,8 @@ class MockServer(layers.http.HttpConnection):
|
||||
|
||||
|
||||
class ReplayHandler(server.ConnectionHandler):
|
||||
layer: layers.HttpLayer
|
||||
|
||||
def __init__(self, flow: http.HTTPFlow, options: Options) -> None:
|
||||
client = flow.client_conn.copy()
|
||||
client.state = ConnectionState.OPEN
|
||||
@ -91,8 +93,9 @@ class ReplayHandler(server.ConnectionHandler):
|
||||
if self.transports:
|
||||
# close server connections
|
||||
for x in self.transports.values():
|
||||
x.handler.cancel()
|
||||
await asyncio.wait([x.handler for x in self.transports.values()])
|
||||
if x.handler:
|
||||
x.handler.cancel()
|
||||
await asyncio.wait([x.handler for x in self.transports.values() if x.handler])
|
||||
# signal completion
|
||||
self.done.set()
|
||||
|
||||
@ -140,6 +143,7 @@ class ClientPlayback:
|
||||
return "Can't replay flow with missing content."
|
||||
else:
|
||||
return "Can only replay HTTP flows."
|
||||
return None
|
||||
|
||||
def load(self, loader):
|
||||
loader.add_option(
|
||||
|
@ -14,7 +14,7 @@ LayerCls = typing.Type[layer.Layer]
|
||||
|
||||
def stack_match(
|
||||
context: context.Context,
|
||||
layers: typing.List[typing.Union[LayerCls, typing.Tuple[LayerCls, ...]]]
|
||||
layers: typing.Sequence[typing.Union[LayerCls, typing.Tuple[LayerCls, ...]]]
|
||||
) -> bool:
|
||||
if len(context.layers) != len(layers):
|
||||
return False
|
||||
@ -74,11 +74,15 @@ class NextLayer:
|
||||
hostnames.append(context.server.address[0])
|
||||
if is_tls_record_magic(data_client):
|
||||
try:
|
||||
sni = parse_client_hello(data_client).sni
|
||||
ch = parse_client_hello(data_client)
|
||||
if ch is None:
|
||||
return None
|
||||
sni = ch.sni
|
||||
except ValueError:
|
||||
return None # defer decision, wait for more input data
|
||||
else:
|
||||
hostnames.append(sni.decode("idna"))
|
||||
if sni:
|
||||
hostnames.append(sni.decode("idna"))
|
||||
|
||||
if not hostnames:
|
||||
return False
|
||||
@ -95,6 +99,8 @@ class NextLayer:
|
||||
for host in hostnames
|
||||
for rex in ctx.options.allow_hosts
|
||||
)
|
||||
else: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
def next_layer(self, nextlayer: layer.NextLayer):
|
||||
if isinstance(nextlayer, base.Layer):
|
||||
@ -106,10 +112,13 @@ class NextLayer:
|
||||
return self.make_top_layer(context)
|
||||
|
||||
if len(data_client) < 3:
|
||||
return
|
||||
return None
|
||||
|
||||
client_tls = is_tls_record_magic(data_client)
|
||||
s = lambda *layers: stack_match(context, layers)
|
||||
|
||||
def s(*layers):
|
||||
return stack_match(context, layers)
|
||||
|
||||
top_layer = context.layers[-1]
|
||||
|
||||
# 1. check for --ignore/--allow
|
||||
@ -117,7 +126,7 @@ class NextLayer:
|
||||
if ignore is True:
|
||||
return layers.TCPLayer(context, ignore=True)
|
||||
if ignore is None:
|
||||
return
|
||||
return None
|
||||
|
||||
# 2. Check for TLS
|
||||
if client_tls:
|
||||
|
@ -51,7 +51,7 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
|
||||
|
||||
def log(self, message: str, level: str = "info") -> None:
|
||||
x = log.LogEntry(self.log_prefix + message, level)
|
||||
x.reply = controller.DummyReply()
|
||||
x.reply = controller.DummyReply() # type: ignore
|
||||
asyncio_utils.create_task(
|
||||
self.master.addons.handle_lifecycle("log", x),
|
||||
name="ProxyConnectionHandler.log"
|
||||
|
@ -33,7 +33,7 @@ class TlsConfig:
|
||||
"""
|
||||
This addon supplies the proxy core with the desired OpenSSL connection objects to negotiate TLS.
|
||||
"""
|
||||
certstore: certs.CertStore = None
|
||||
certstore: certs.CertStore
|
||||
|
||||
# TODO: We should support configuring TLS 1.3 cipher suites (https://github.com/mitmproxy/mitmproxy/issues/4260)
|
||||
# TODO: We should re-use SSL.Context options here, if only for TLS session resumption.
|
||||
@ -57,7 +57,7 @@ class TlsConfig:
|
||||
our certificate should have and then fetches a matching cert from the certstore.
|
||||
"""
|
||||
altnames: List[bytes] = []
|
||||
organization: Optional[str] = None
|
||||
organization: Optional[bytes] = None
|
||||
|
||||
# Use upstream certificate if available.
|
||||
if conn_context.server.certificate_list:
|
||||
@ -130,6 +130,7 @@ class TlsConfig:
|
||||
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None:
|
||||
client = tls_start.context.client
|
||||
server = cast(context.Server, tls_start.conn)
|
||||
assert server.address
|
||||
|
||||
if server.sni is True:
|
||||
server.sni = client.sni or server.address[0].encode()
|
||||
@ -179,7 +180,7 @@ class TlsConfig:
|
||||
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
|
||||
ssl_ctx = net_tls.create_client_context(
|
||||
cert=client_cert,
|
||||
sni=server.sni.decode("idna"), # TODO: Should pass-through here.
|
||||
sni=server.sni.decode("idna") if server.sni else None, # TODO: Should pass-through here.
|
||||
alpn_protos=server.alpn_offers,
|
||||
**args
|
||||
)
|
||||
|
@ -16,26 +16,27 @@ def _parse_authority_form(hostport: bytes) -> Tuple[bytes, int]:
|
||||
ValueError, if the input is malformed
|
||||
"""
|
||||
try:
|
||||
host, port = hostport.rsplit(b":", 1)
|
||||
host, port_str = hostport.rsplit(b":", 1)
|
||||
if host.startswith(b"[") and host.endswith(b"]"):
|
||||
host = host[1:-1]
|
||||
port = int(port)
|
||||
port = int(port_str)
|
||||
if not check.is_valid_host(host) or not check.is_valid_port(port):
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid host specification: {hostport}")
|
||||
raise ValueError(f"Invalid host specification: {hostport!r}")
|
||||
|
||||
return host, port
|
||||
|
||||
|
||||
def raise_if_http_version_unknown(http_version):
|
||||
def raise_if_http_version_unknown(http_version: bytes) -> None:
|
||||
if not re.match(br"^HTTP/\d\.\d$", http_version):
|
||||
raise ValueError(f"Unknown HTTP version: {http_version}")
|
||||
raise ValueError(f"Unknown HTTP version: {http_version!r}")
|
||||
|
||||
|
||||
def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, bytes, bytes]:
|
||||
try:
|
||||
method, target, http_version = line.split()
|
||||
port: Optional[int]
|
||||
|
||||
if target == b"*" or target.startswith(b"/"):
|
||||
scheme, authority, path = b"", b"", target
|
||||
@ -58,7 +59,7 @@ def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, byte
|
||||
|
||||
raise_if_http_version_unknown(http_version)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Bad HTTP request line: {line}") from e
|
||||
raise ValueError(f"Bad HTTP request line: {line!r}") from e
|
||||
|
||||
return host, port, method, scheme, authority, path, http_version
|
||||
|
||||
@ -69,16 +70,16 @@ def _read_response_line(line: bytes) -> Tuple[bytes, int, bytes]:
|
||||
if len(parts) == 2: # handle missing message gracefully
|
||||
parts.append(b"")
|
||||
|
||||
http_version, status_code, reason = parts
|
||||
status_code = int(status_code)
|
||||
http_version, status_code_str, reason = parts
|
||||
status_code = int(status_code_str)
|
||||
raise_if_http_version_unknown(http_version)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Bad HTTP response line: {line}") from e
|
||||
raise ValueError(f"Bad HTTP response line: {line!r}") from e
|
||||
|
||||
return http_version, status_code, reason
|
||||
|
||||
|
||||
def _read_headers(lines: Iterable[bytes]):
|
||||
def _read_headers(lines: Iterable[bytes]) -> headers.Headers:
|
||||
"""
|
||||
Read a set of headers.
|
||||
Stop once a blank line is reached.
|
||||
@ -89,7 +90,7 @@ def _read_headers(lines: Iterable[bytes]):
|
||||
Raises:
|
||||
exceptions.HttpSyntaxException
|
||||
"""
|
||||
ret = []
|
||||
ret: List[Tuple[bytes, bytes]] = []
|
||||
for line in lines:
|
||||
if line[0] in b" \t":
|
||||
if not ret:
|
||||
@ -104,7 +105,7 @@ def _read_headers(lines: Iterable[bytes]):
|
||||
raise ValueError()
|
||||
ret.append((name, value))
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid header line: {line}")
|
||||
raise ValueError(f"Invalid header line: {line!r}")
|
||||
return headers.Headers(ret)
|
||||
|
||||
|
||||
|
@ -8,17 +8,20 @@ The counterpart to commands are events.
|
||||
"""
|
||||
import dataclasses
|
||||
import re
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Type
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, TYPE_CHECKING
|
||||
|
||||
from mitmproxy.proxy2.context import Connection, Server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import mitmproxy.proxy2.layer
|
||||
|
||||
|
||||
class Command:
|
||||
"""
|
||||
Base class for all commands
|
||||
"""
|
||||
|
||||
blocking: ClassVar[bool] = False
|
||||
blocking: Union[bool, "mitmproxy.proxy2.layer.Layer"] = False
|
||||
"""
|
||||
Determines if the command blocks until it has been completed.
|
||||
|
||||
|
@ -1,13 +1,16 @@
|
||||
import uuid
|
||||
import warnings
|
||||
from enum import Flag
|
||||
from typing import List, Literal, Optional, Sequence, Tuple, Union
|
||||
from typing import List, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.coretypes import serializable
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.options import Options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import mitmproxy.proxy2.layer
|
||||
|
||||
|
||||
class ConnectionState(Flag):
|
||||
CLOSED = 0
|
||||
@ -38,9 +41,9 @@ class Connection(serializable.Serializable):
|
||||
tls: bool = False
|
||||
certificate_list: Optional[Sequence[certs.Cert]] = None
|
||||
"""
|
||||
The TLS certificate list as sent by the peer.
|
||||
The TLS certificate list as sent by the peer.
|
||||
The first certificate is the end-entity certificate.
|
||||
|
||||
|
||||
[RFC 8446] Prior to TLS 1.3, "certificate_list" ordering required each
|
||||
certificate to certify the one immediately preceding it; however,
|
||||
some implementations allowed some flexibility. Servers sometimes
|
||||
@ -84,7 +87,7 @@ class Connection(serializable.Serializable):
|
||||
return f"{type(self).__name__}({attrs})"
|
||||
|
||||
@property
|
||||
def alpn_proto_negotiated(self) -> bytes:
|
||||
def alpn_proto_negotiated(self) -> Optional[bytes]:
|
||||
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", PendingDeprecationWarning)
|
||||
return self.alpn
|
||||
|
||||
@ -185,11 +188,11 @@ class Server(Connection):
|
||||
timestamp_tcp_setup: Optional[float] = None
|
||||
"""TCP ACK received"""
|
||||
|
||||
sni = True
|
||||
sni: Union[bytes, Literal[True], None] = True
|
||||
"""True: client SNI, False: no SNI, bytes: custom value"""
|
||||
via: Optional[server_spec.ServerSpec] = None
|
||||
|
||||
def __init__(self, address: Optional[tuple]):
|
||||
def __init__(self, address: Optional[Address]):
|
||||
self.id = str(uuid.uuid4())
|
||||
self.address = address
|
||||
|
||||
@ -250,7 +253,7 @@ class Server(Connection):
|
||||
self.via = state["via2"]
|
||||
|
||||
@property
|
||||
def ip_address(self) -> Address:
|
||||
def ip_address(self) -> Optional[Address]:
|
||||
warnings.warn("Server.ip_address is deprecated, use Server.peername instead.", PendingDeprecationWarning)
|
||||
return self.peername
|
||||
|
||||
|
@ -3,22 +3,22 @@ Base class for protocol layers.
|
||||
"""
|
||||
import collections
|
||||
import textwrap
|
||||
import typing
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, List, ClassVar, Deque, NamedTuple, Generator, Any, TypeVar
|
||||
|
||||
from mitmproxy import log
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.commands import Command, Hook
|
||||
from mitmproxy.proxy2.context import Connection, Context
|
||||
|
||||
T = typing.TypeVar('T')
|
||||
CommandGenerator = typing.Generator[Command, typing.Optional[events.CommandReply], T]
|
||||
T = TypeVar('T')
|
||||
CommandGenerator = Generator[Command, Any, T]
|
||||
"""
|
||||
A function annotated with CommandGenerator[bool] may yield commands and ultimately return a boolean value.
|
||||
"""
|
||||
|
||||
|
||||
class Paused(typing.NamedTuple):
|
||||
class Paused(NamedTuple):
|
||||
"""
|
||||
State of a layer that's paused because it is waiting for a command reply.
|
||||
"""
|
||||
@ -27,11 +27,11 @@ class Paused(typing.NamedTuple):
|
||||
|
||||
|
||||
class Layer:
|
||||
__last_debug_message: typing.ClassVar[str] = ""
|
||||
__last_debug_message: ClassVar[str] = ""
|
||||
context: Context
|
||||
_paused: typing.Optional[Paused]
|
||||
_paused_event_queue: typing.Deque[events.Event]
|
||||
debug: typing.Optional[str] = None
|
||||
_paused: Optional[Paused]
|
||||
_paused_event_queue: Deque[events.Event]
|
||||
debug: Optional[str] = None
|
||||
"""
|
||||
Enable debug logging by assigning a prefix string for log messages.
|
||||
Different amounts of whitespace for different layers work well.
|
||||
@ -94,6 +94,7 @@ class Layer:
|
||||
if self.debug is not None:
|
||||
yield self.__debug(f"{'>>' if pause_finished else '>!'} {event}")
|
||||
if pause_finished:
|
||||
assert isinstance(event, events.CommandReply)
|
||||
yield from self.__continue(event)
|
||||
else:
|
||||
self._paused_event_queue.append(event)
|
||||
@ -114,7 +115,7 @@ class Layer:
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
while command:
|
||||
while True:
|
||||
if self.debug is not None:
|
||||
if not isinstance(command, commands.Log):
|
||||
yield self.__debug(f"<< {command}")
|
||||
@ -128,19 +129,23 @@ class Layer:
|
||||
return
|
||||
else:
|
||||
yield command
|
||||
command = next(command_generator, None)
|
||||
try:
|
||||
command = next(command_generator)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
def __continue(self, event: events.CommandReply):
|
||||
"""continue processing events after being paused"""
|
||||
assert self._paused is not None
|
||||
command_generator = self._paused.generator
|
||||
self._paused = None
|
||||
yield from self.__process(command_generator, event.reply)
|
||||
|
||||
while not self._paused and self._paused_event_queue:
|
||||
event = self._paused_event_queue.popleft()
|
||||
ev = self._paused_event_queue.popleft()
|
||||
if self.debug is not None:
|
||||
yield self.__debug(f"!> {event}")
|
||||
command_generator = self._handle_event(event)
|
||||
yield self.__debug(f"!> {ev}")
|
||||
command_generator = self._handle_event(ev)
|
||||
yield from self.__process(command_generator)
|
||||
|
||||
|
||||
@ -152,10 +157,10 @@ class NextLayerHook(Hook):
|
||||
|
||||
|
||||
class NextLayer(Layer):
|
||||
layer: typing.Optional[Layer]
|
||||
layer: Optional[Layer]
|
||||
"""The next layer. To be set by an addon."""
|
||||
|
||||
events: typing.List[mevents.Event]
|
||||
events: List[mevents.Event]
|
||||
"""All events that happened before a decision was made."""
|
||||
|
||||
_ask_on_start: bool
|
||||
|
@ -1,7 +1,7 @@
|
||||
import collections
|
||||
import time
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union, Dict, DefaultDict, List
|
||||
|
||||
from mitmproxy import flow, http
|
||||
from mitmproxy.net import server_spec
|
||||
@ -13,7 +13,7 @@ from mitmproxy.proxy2.layers import tls, websocket, tcp
|
||||
from mitmproxy.proxy2.layers.http import _upstream_proxy
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
|
||||
from ._base import HttpCommand, ReceiveHttp, StreamId, HttpConnection
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
||||
@ -22,7 +22,7 @@ from ._http1 import Http1Client, Http1Server
|
||||
from ._http2 import Http2Client, Http2Server
|
||||
|
||||
|
||||
def validate_request(mode, request) -> typing.Optional[str]:
|
||||
def validate_request(mode, request) -> Optional[str]:
|
||||
if request.scheme not in ("http", "https", ""):
|
||||
return f"Invalid request scheme: {request.scheme}"
|
||||
if mode is HTTPMode.transparent and request.method == "CONNECT":
|
||||
@ -39,9 +39,9 @@ class GetHttpConnection(HttpCommand):
|
||||
Open an HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead.
|
||||
"""
|
||||
blocking = True
|
||||
address: typing.Tuple[str, int]
|
||||
address: Tuple[str, int]
|
||||
tls: bool
|
||||
via: typing.Optional[server_spec.ServerSpec]
|
||||
via: Optional[server_spec.ServerSpec]
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
@ -61,29 +61,23 @@ class GetHttpConnection(HttpCommand):
|
||||
@dataclass
|
||||
class GetHttpConnectionReply(events.CommandReply):
|
||||
command: GetHttpConnection
|
||||
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
|
||||
reply: Union[Tuple[None, str], Tuple[Connection, None]]
|
||||
"""connection object, error message"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisterHttpConnection(HttpCommand):
|
||||
"""
|
||||
Register that a HTTP connection has been successfully established.
|
||||
Register that a HTTP connection attempt has been completed.
|
||||
"""
|
||||
connection: Connection
|
||||
err: str
|
||||
|
||||
def __init__(self, connection: Connection, err: str):
|
||||
self.connection = connection
|
||||
self.err = err
|
||||
err: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendHttp(HttpCommand):
|
||||
connection: Connection
|
||||
event: HttpEvent
|
||||
|
||||
def __init__(self, event: HttpEvent, connection: Connection):
|
||||
self.connection = connection
|
||||
self.event = event
|
||||
connection: Connection
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Send({self.event})"
|
||||
@ -93,8 +87,8 @@ class HttpStream(layer.Layer):
|
||||
request_body_buf: bytes
|
||||
response_body_buf: bytes
|
||||
flow: http.HTTPFlow
|
||||
stream_id: StreamId = None
|
||||
child_layer: typing.Optional[layer.Layer] = None
|
||||
stream_id: StreamId
|
||||
child_layer: Optional[layer.Layer] = None
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
@ -102,12 +96,13 @@ class HttpStream(layer.Layer):
|
||||
parent: HttpLayer = self.context.layers[i - 1]
|
||||
return parent.mode
|
||||
|
||||
def __init__(self, context: Context):
|
||||
def __init__(self, context: Context, stream_id: int):
|
||||
super().__init__(context)
|
||||
self.request_body_buf = b""
|
||||
self.response_body_buf = b""
|
||||
self.client_state = self.state_uninitialized
|
||||
self.server_state = self.state_uninitialized
|
||||
self.stream_id = stream_id
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
@ -131,7 +126,6 @@ class HttpStream(layer.Layer):
|
||||
|
||||
@expect(RequestHeaders)
|
||||
def state_wait_for_request_headers(self, event: RequestHeaders) -> layer.CommandGenerator[None]:
|
||||
self.stream_id = event.stream_id
|
||||
if not event.replay_flow:
|
||||
self.flow = http.HTTPFlow(
|
||||
self.context.client,
|
||||
@ -152,6 +146,7 @@ class HttpStream(layer.Layer):
|
||||
|
||||
if self.mode is HTTPMode.transparent:
|
||||
# Determine .scheme, .host and .port attributes for transparent requests
|
||||
assert self.context.server.address
|
||||
self.flow.request.data.host = self.context.server.address[0]
|
||||
self.flow.request.data.port = self.context.server.address[1]
|
||||
self.flow.request.scheme = "https" if self.context.server.tls else "http"
|
||||
@ -176,10 +171,11 @@ class HttpStream(layer.Layer):
|
||||
if self.mode is HTTPMode.regular and not self.flow.request.is_http2:
|
||||
# Set the request target to origin-form for HTTP/1, some servers don't support absolute-form requests.
|
||||
# see https://github.com/mitmproxy/mitmproxy/issues/1759
|
||||
self.flow.request.authority = b""
|
||||
self.flow.request.authority = ""
|
||||
|
||||
# update host header in reverse proxy mode
|
||||
if self.context.options.mode.startswith("reverse:") and not self.context.options.keep_host_header:
|
||||
assert self.context.server.address
|
||||
self.flow.request.host_header = url.hostport(
|
||||
"https" if self.context.server.tls else "http",
|
||||
self.context.server.address[0],
|
||||
@ -210,7 +206,7 @@ class HttpStream(layer.Layer):
|
||||
self.server_state = self.state_wait_for_response_headers
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage)
|
||||
def state_stream_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
def state_stream_request_body(self, event: Union[RequestData, RequestEndOfMessage]) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, RequestData):
|
||||
if callable(self.flow.request.stream):
|
||||
event.data = self.flow.request.stream(event.data)
|
||||
@ -255,10 +251,10 @@ class HttpStream(layer.Layer):
|
||||
if not ok:
|
||||
return
|
||||
|
||||
has_content = bool(self.flow.request.raw_content)
|
||||
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not has_content), self.context.server)
|
||||
if has_content:
|
||||
yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server)
|
||||
content = self.flow.request.raw_content
|
||||
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not content), self.context.server)
|
||||
if content:
|
||||
yield SendHttp(RequestData(self.stream_id, content), self.context.server)
|
||||
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
||||
|
||||
@expect(ResponseHeaders)
|
||||
@ -275,6 +271,7 @@ class HttpStream(layer.Layer):
|
||||
|
||||
@expect(ResponseData, ResponseEndOfMessage)
|
||||
def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.flow.response
|
||||
if isinstance(event, ResponseData):
|
||||
if callable(self.flow.response.stream):
|
||||
data = self.flow.response.stream(event.data)
|
||||
@ -289,12 +286,14 @@ class HttpStream(layer.Layer):
|
||||
if isinstance(event, ResponseData):
|
||||
self.response_body_buf += event.data
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
assert self.flow.response
|
||||
self.flow.response.data.content = self.response_body_buf
|
||||
self.response_body_buf = b""
|
||||
yield from self.send_response()
|
||||
|
||||
def send_response(self, already_streamed: bool = False):
|
||||
"""We have either consumed the entire response from the server or the response was set by an addon."""
|
||||
assert self.flow.response
|
||||
self.flow.response.timestamp_end = time.time()
|
||||
yield HttpResponseHook(self.flow)
|
||||
self.server_state = self.state_done
|
||||
@ -302,10 +301,10 @@ class HttpStream(layer.Layer):
|
||||
return
|
||||
|
||||
if not already_streamed:
|
||||
has_content = bool(self.flow.response.raw_content)
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
|
||||
if has_content:
|
||||
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
|
||||
content = self.flow.response.raw_content
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not content), self.context.client)
|
||||
if content:
|
||||
yield SendHttp(ResponseData(self.stream_id, content), self.context.client)
|
||||
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
|
||||
@ -331,7 +330,7 @@ class HttpStream(layer.Layer):
|
||||
)
|
||||
# The client may have closed the connection while we were waiting for the hook to complete.
|
||||
# We peek into the event queue to see if that is the case.
|
||||
killed_by_remote = False
|
||||
killed_by_remote = None
|
||||
for evt in self._paused_event_queue:
|
||||
if isinstance(evt, RequestProtocolError):
|
||||
killed_by_remote = evt.message
|
||||
@ -356,7 +355,7 @@ class HttpStream(layer.Layer):
|
||||
|
||||
def handle_protocol_error(
|
||||
self,
|
||||
event: typing.Union[RequestProtocolError, ResponseProtocolError]
|
||||
event: Union[RequestProtocolError, ResponseProtocolError]
|
||||
) -> layer.CommandGenerator[None]:
|
||||
is_client_error_but_we_already_talk_upstream = (
|
||||
isinstance(event, RequestProtocolError)
|
||||
@ -450,6 +449,8 @@ class HttpStream(layer.Layer):
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage, events.Event)
|
||||
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.flow.response
|
||||
assert self.child_layer
|
||||
# HTTP events -> normal connection events
|
||||
if isinstance(event, RequestData):
|
||||
event = events.DataReceived(self.context.client, event.data)
|
||||
@ -509,10 +510,10 @@ class HttpLayer(layer.Layer):
|
||||
ConnectionEvent -> HttpEvent -> HttpCommand -> ConnectionCommand
|
||||
"""
|
||||
mode: HTTPMode
|
||||
command_sources: typing.Dict[commands.Command, layer.Layer]
|
||||
streams: typing.Dict[int, HttpStream]
|
||||
connections: typing.Dict[Connection, layer.Layer]
|
||||
waiting_for_establishment: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
|
||||
command_sources: Dict[commands.Command, layer.Layer]
|
||||
streams: Dict[int, HttpStream]
|
||||
connections: Dict[Connection, layer.Layer]
|
||||
waiting_for_establishment: DefaultDict[Connection, List[GetHttpConnection]]
|
||||
|
||||
def __init__(self, context: Context, mode: HTTPMode):
|
||||
super().__init__(context)
|
||||
@ -522,6 +523,7 @@ class HttpLayer(layer.Layer):
|
||||
self.streams = {}
|
||||
self.command_sources = {}
|
||||
|
||||
http_conn: HttpConnection
|
||||
if self.context.client.alpn == b"h2":
|
||||
http_conn = Http2Server(context.fork())
|
||||
else:
|
||||
@ -554,7 +556,7 @@ class HttpLayer(layer.Layer):
|
||||
|
||||
def event_to_child(
|
||||
self,
|
||||
child: typing.Union[layer.Layer, HttpStream],
|
||||
child: Union[layer.Layer, HttpStream],
|
||||
event: events.Event,
|
||||
) -> layer.CommandGenerator[None]:
|
||||
for command in child.handle_event(event):
|
||||
@ -567,7 +569,7 @@ class HttpLayer(layer.Layer):
|
||||
|
||||
if isinstance(command, ReceiveHttp):
|
||||
if isinstance(command.event, RequestHeaders):
|
||||
self.streams[command.event.stream_id] = yield from self.make_stream()
|
||||
yield from self.make_stream(command.event.stream_id)
|
||||
stream = self.streams[command.event.stream_id]
|
||||
yield from self.event_to_child(stream, command.event)
|
||||
elif isinstance(command, SendHttp):
|
||||
@ -585,11 +587,10 @@ class HttpLayer(layer.Layer):
|
||||
else:
|
||||
raise AssertionError(f"Not a command: {event}")
|
||||
|
||||
def make_stream(self) -> layer.CommandGenerator[HttpStream]:
|
||||
def make_stream(self, stream_id: int) -> layer.CommandGenerator[None]:
|
||||
ctx = self.context.fork()
|
||||
stream = HttpStream(ctx)
|
||||
yield from self.event_to_child(stream, events.Start())
|
||||
return stream
|
||||
self.streams[stream_id] = HttpStream(ctx, stream_id)
|
||||
yield from self.event_to_child(self.streams[stream_id], events.Start())
|
||||
|
||||
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True) -> layer.CommandGenerator[None]:
|
||||
# Do we already have a connection we can re-use?
|
||||
@ -649,6 +650,7 @@ class HttpLayer(layer.Layer):
|
||||
def register_connection(self, command: RegisterHttpConnection) -> layer.CommandGenerator[None]:
|
||||
waiting = self.waiting_for_establishment.pop(command.connection)
|
||||
|
||||
reply: Union[Tuple[None, str], Tuple[Connection, None]]
|
||||
if command.err:
|
||||
reply = (None, command.err)
|
||||
else:
|
||||
@ -675,11 +677,13 @@ class HttpLayer(layer.Layer):
|
||||
class HttpClient(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
err: Optional[str]
|
||||
if self.context.server.connected:
|
||||
err = None
|
||||
else:
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if not err:
|
||||
child_layer: layer.Layer
|
||||
if self.context.server.alpn == b"h2":
|
||||
child_layer = Http2Client(self.context)
|
||||
else:
|
||||
|
@ -27,7 +27,8 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
response: Optional[http.HTTPResponse] = None
|
||||
request_done: bool = False
|
||||
response_done: bool = False
|
||||
state: Callable[[events.ConnectionEvent], layer.CommandGenerator[None]]
|
||||
# this is a bit of a hack to make both mypy and PyCharm happy.
|
||||
state: Union[Callable[[events.ConnectionEvent], layer.CommandGenerator[None]], Callable]
|
||||
body_reader: TBodyReader
|
||||
buf: ReceiveBuffer
|
||||
|
||||
@ -50,10 +51,12 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, HttpEvent):
|
||||
yield from self.send(event)
|
||||
else:
|
||||
elif isinstance(event, events.ConnectionEvent):
|
||||
if isinstance(event, events.DataReceived) and self.state != self.passthrough:
|
||||
self.buf += event.data
|
||||
yield from self.state(event)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
@ -63,6 +66,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
state = start
|
||||
|
||||
def read_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.stream_id
|
||||
while True:
|
||||
try:
|
||||
if isinstance(event, events.DataReceived):
|
||||
@ -83,6 +87,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
if data:
|
||||
yield ReceiveHttp(self.ReceiveData(self.stream_id, data))
|
||||
elif isinstance(h11_event, h11.EndOfMessage):
|
||||
assert self.request
|
||||
if h11_event.headers:
|
||||
raise NotImplementedError(f"HTTP trailers are not implemented yet.")
|
||||
if self.request.data.method.upper() != b"CONNECT":
|
||||
@ -99,6 +104,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
We wait for the current flow to be finished before parsing the next message,
|
||||
as we may want to upgrade to WebSocket or plain TCP before that.
|
||||
"""
|
||||
assert self.stream_id
|
||||
if isinstance(event, events.DataReceived):
|
||||
return
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
@ -123,6 +129,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
self.buf.compress()
|
||||
|
||||
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.stream_id
|
||||
if isinstance(event, events.DataReceived):
|
||||
yield ReceiveHttp(self.ReceiveData(self.stream_id, event.data))
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
@ -137,19 +144,19 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
if response:
|
||||
self.response_done = True
|
||||
if self.request_done and self.response_done:
|
||||
assert self.request
|
||||
assert self.response
|
||||
if should_make_pipe(self.request, self.response):
|
||||
yield from self.make_pipe()
|
||||
return
|
||||
connection_done = (
|
||||
http1_sansio.expected_http_body_size(self.request, self.response) == -1 or
|
||||
http1.connection_close(self.request.http_version, self.request.headers) or
|
||||
http1.connection_close(self.response.http_version, self.response.headers) or
|
||||
(
|
||||
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
|
||||
# This simplifies our connection management quite a bit as we can rely on
|
||||
# the proxyserver's max-connection-per-server throttling.
|
||||
self.request.is_http2 and isinstance(self, Http1Client)
|
||||
)
|
||||
http1_sansio.expected_http_body_size(self.request, self.response) == -1
|
||||
or http1.connection_close(self.request.http_version, self.request.headers)
|
||||
or http1.connection_close(self.response.http_version, self.response.headers)
|
||||
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
|
||||
# This simplifies our connection management quite a bit as we can rely on
|
||||
# the proxyserver's max-connection-per-server throttling.
|
||||
or (self.request.is_http2 and isinstance(self, Http1Client))
|
||||
)
|
||||
if connection_done:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
@ -172,6 +179,7 @@ class Http1Server(Http1Connection):
|
||||
ReceiveProtocolError = RequestProtocolError
|
||||
ReceiveData = RequestData
|
||||
ReceiveEndOfMessage = RequestEndOfMessage
|
||||
stream_id: int
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.client)
|
||||
@ -185,7 +193,7 @@ class Http1Server(Http1Connection):
|
||||
if response.is_http2:
|
||||
response = response.copy()
|
||||
# Convert to an HTTP/1 response.
|
||||
response.http_version = b"HTTP/1.1"
|
||||
response.http_version = "HTTP/1.1"
|
||||
# not everyone supports empty reason phrases, so we better make up one.
|
||||
response.reason = status_codes.RESPONSES.get(response.status_code, "")
|
||||
# Shall we set a Content-Length header here if there is none?
|
||||
@ -194,6 +202,7 @@ class Http1Server(Http1Connection):
|
||||
raw = http1.assemble_response_head(response)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, ResponseData):
|
||||
assert self.response
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||
else:
|
||||
@ -201,6 +210,7 @@ class Http1Server(Http1Connection):
|
||||
if raw:
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
assert self.response
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
yield from self.mark_done(response=True)
|
||||
@ -235,7 +245,7 @@ class Http1Server(Http1Connection):
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
buf = bytes(self.buf)
|
||||
if buf.strip():
|
||||
yield commands.Log(f"Client closed connection before completing request headers: {buf}")
|
||||
yield commands.Log(f"Client closed connection before completing request headers: {buf!r}")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
@ -268,13 +278,14 @@ class Http1Client(Http1Connection):
|
||||
if request.is_http2:
|
||||
# Convert to an HTTP/1 request.
|
||||
request = request.copy() # (we could probably be a bit more efficient here.)
|
||||
request.http_version = b"HTTP/1.1"
|
||||
request.http_version = "HTTP/1.1"
|
||||
if "Host" not in request.headers and request.authority:
|
||||
request.headers.insert(0, "Host", request.authority)
|
||||
request.authority = b""
|
||||
request.authority = ""
|
||||
raw = http1.assemble_request_head(request)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, RequestData):
|
||||
assert self.request
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||
else:
|
||||
@ -282,6 +293,7 @@ class Http1Client(Http1Connection):
|
||||
if raw:
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
assert self.request
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
elif http1_sansio.expected_http_body_size(self.request, self.response) == -1:
|
||||
@ -300,6 +312,7 @@ class Http1Client(Http1Connection):
|
||||
yield commands.Log(f"Unexpected data from server: {bytes(self.buf)!r}")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
assert self.stream_id
|
||||
|
||||
response_head = self.buf.maybe_extract_lines()
|
||||
if response_head:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import collections
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import ClassVar, DefaultDict, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||
from typing import ClassVar, DefaultDict, Dict, List, Optional, Tuple, Type, Union, Sequence
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
@ -76,6 +76,7 @@ class Http2Connection(HttpConnection):
|
||||
|
||||
elif isinstance(event, HttpEvent):
|
||||
if isinstance(event, self.SendData):
|
||||
assert isinstance(event, (RequestData, ResponseData))
|
||||
self.h2_conn.send_data(event.stream_id, event.data)
|
||||
elif isinstance(event, self.SendEndOfMessage):
|
||||
stream = self.h2_conn.streams.get(event.stream_id)
|
||||
@ -85,6 +86,7 @@ class Http2Connection(HttpConnection):
|
||||
if self.is_closed(event.stream_id):
|
||||
self.streams.pop(event.stream_id, None)
|
||||
elif isinstance(event, self.SendProtocolError):
|
||||
assert isinstance(event, (RequestProtocolError, ResponseProtocolError))
|
||||
stream = self.h2_conn.streams.get(event.stream_id)
|
||||
if stream.state_machine.state is not h2.stream.StreamState.CLOSED:
|
||||
code = {
|
||||
@ -192,6 +194,7 @@ class Http2Connection(HttpConnection):
|
||||
yield Log(f"Ignoring unknown HTTP/2 frame type: {event.frame.type}")
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
return False
|
||||
|
||||
def protocol_error(
|
||||
self,
|
||||
@ -208,7 +211,7 @@ class Http2Connection(HttpConnection):
|
||||
for stream_id in self.streams:
|
||||
yield ReceiveHttp(self.ReceiveProtocolError(stream_id, msg))
|
||||
self.streams.clear()
|
||||
self._handle_event = self.done
|
||||
self._handle_event = self.done # type: ignore
|
||||
|
||||
@expect(DataReceived, HttpEvent, ConnectionClosed)
|
||||
def done(self, _) -> CommandGenerator[None]:
|
||||
@ -285,6 +288,7 @@ class Http2Server(Http2Connection):
|
||||
)
|
||||
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
|
||||
yield ReceiveHttp(RequestHeaders(event.stream_id, request, end_stream=bool(event.stream_ended)))
|
||||
return False
|
||||
else:
|
||||
return (yield from super().handle_h2_event(event))
|
||||
|
||||
@ -303,9 +307,9 @@ class Http2Client(Http2Connection):
|
||||
ReceiveData = ResponseData
|
||||
ReceiveEndOfMessage = ResponseEndOfMessage
|
||||
|
||||
our_stream_id = Dict[int, int]
|
||||
their_stream_id = Dict[int, int]
|
||||
stream_queue = DefaultDict[int, List[Event]]
|
||||
our_stream_id: Dict[int, int]
|
||||
their_stream_id: Dict[int, int]
|
||||
stream_queue: DefaultDict[int, List[Event]]
|
||||
"""Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
|
||||
provisional_max_concurrency: Optional[int] = 10
|
||||
"""A provisional currency limit before we get the server's first settings frame."""
|
||||
@ -360,9 +364,9 @@ class Http2Client(Http2Connection):
|
||||
def _handle_event2(self, event: Event) -> CommandGenerator[None]:
|
||||
if isinstance(event, RequestHeaders):
|
||||
pseudo_headers = [
|
||||
(b':method', event.request.method),
|
||||
(b':scheme', event.request.scheme),
|
||||
(b':path', event.request.path),
|
||||
(b':method', event.request.data.method),
|
||||
(b':scheme', event.request.data.scheme),
|
||||
(b':path', event.request.data.path),
|
||||
]
|
||||
if event.request.authority:
|
||||
pseudo_headers.append((b":authority", event.request.data.authority))
|
||||
@ -410,6 +414,7 @@ class Http2Client(Http2Connection):
|
||||
)
|
||||
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
|
||||
yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended)))
|
||||
return False
|
||||
elif isinstance(event, h2.events.RequestReceived):
|
||||
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")
|
||||
return True
|
||||
@ -422,13 +427,13 @@ class Http2Client(Http2Connection):
|
||||
return (yield from super().handle_h2_event(event))
|
||||
|
||||
|
||||
def split_pseudo_headers(h2_headers: Iterable[Tuple[bytes, bytes]]) -> Tuple[Dict[bytes, bytes], net_http.Headers]:
|
||||
def split_pseudo_headers(h2_headers: Sequence[Tuple[bytes, bytes]]) -> Tuple[Dict[bytes, bytes], net_http.Headers]:
|
||||
pseudo_headers: Dict[bytes, bytes] = {}
|
||||
i = 0
|
||||
for (header, value) in h2_headers:
|
||||
if header.startswith(b":"):
|
||||
if header in pseudo_headers:
|
||||
raise ValueError(f"Duplicate HTTP/2 pseudo header: {header}")
|
||||
raise ValueError(f"Duplicate HTTP/2 pseudo header: {header!r}")
|
||||
pseudo_headers[header] = value
|
||||
i += 1
|
||||
else:
|
||||
@ -441,7 +446,7 @@ def split_pseudo_headers(h2_headers: Iterable[Tuple[bytes, bytes]]) -> Tuple[Dic
|
||||
|
||||
|
||||
def parse_h2_request_headers(
|
||||
h2_headers: Iterable[Tuple[bytes, bytes]]
|
||||
h2_headers: Sequence[Tuple[bytes, bytes]]
|
||||
) -> Tuple[str, int, bytes, bytes, bytes, bytes, net_http.Headers]:
|
||||
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
|
||||
pseudo_headers, headers = split_pseudo_headers(h2_headers)
|
||||
@ -468,7 +473,7 @@ def parse_h2_request_headers(
|
||||
return host, port, method, scheme, authority, path, headers
|
||||
|
||||
|
||||
def parse_h2_response_headers(h2_headers: Iterable[Tuple[bytes, bytes]]) -> Tuple[int, net_http.Headers]:
|
||||
def parse_h2_response_headers(h2_headers: Sequence[Tuple[bytes, bytes]]) -> Tuple[int, net_http.Headers]:
|
||||
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
|
||||
pseudo_headers, headers = split_pseudo_headers(h2_headers)
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
import collections
|
||||
from typing import DefaultDict, Deque, NamedTuple
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import h2.settings
|
||||
import h2.exceptions
|
||||
from typing import DefaultDict, Deque, NamedTuple, Optional
|
||||
import h2.settings
|
||||
|
||||
|
||||
class H2ConnectionLogger(h2.config.DummyLogger):
|
||||
|
@ -28,6 +28,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
|
||||
conn=ctx.server
|
||||
)
|
||||
|
||||
assert self.tunnel_connection.address
|
||||
self.conn.via = server_spec.ServerSpec(
|
||||
"https" if self.tunnel_connection.tls else "http",
|
||||
self.tunnel_connection.address
|
||||
@ -43,6 +44,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
|
||||
self.conn.alpn = self.tunnel_connection.alpn
|
||||
if not self.send_connect:
|
||||
return (yield from super().start_handshake())
|
||||
assert self.conn.address
|
||||
req = http.make_connect_request(self.conn.address)
|
||||
raw = http1.assemble_request(req)
|
||||
yield commands.SendData(self.tunnel_connection, raw)
|
||||
@ -66,7 +68,8 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
|
||||
return True, None
|
||||
else:
|
||||
raw_resp = b"\n".join(response_head)
|
||||
yield commands.Log(f"{human.format_address(self.tunnel_connection.address)}: {raw_resp}", level="debug")
|
||||
yield commands.Log(f"{human.format_address(self.tunnel_connection.address)}: {raw_resp!r}",
|
||||
level="debug")
|
||||
return False, f"{response.status_code} {response.reason}"
|
||||
else:
|
||||
return False, None
|
||||
|
@ -11,6 +11,7 @@ class ReverseProxy(layer.Layer):
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
|
||||
self.context.server = Server(spec.address)
|
||||
child_layer: layer.Layer
|
||||
if spec.scheme not in ("http", "tcp"):
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0].encode()
|
||||
@ -32,6 +33,7 @@ class HttpProxy(layer.Layer):
|
||||
class TransparentProxy(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert platform.original_addr is not None
|
||||
socket = yield commands.GetSocket(self.context.client)
|
||||
try:
|
||||
self.context.server.address = platform.original_addr(socket)
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||
from mitmproxy import flow, tcp
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.commands import Hook
|
||||
from mitmproxy.proxy2.context import ConnectionState, Context
|
||||
from mitmproxy.proxy2.context import ConnectionState, Context, Connection
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
|
||||
|
||||
@ -73,6 +73,7 @@ class TCPLayer(layer.Layer):
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
from_client = event.connection == self.context.client
|
||||
send_to: Connection
|
||||
if from_client:
|
||||
send_to = self.context.server
|
||||
else:
|
||||
|
@ -42,7 +42,7 @@ def handshake_record_contents(data: bytes) -> Iterator[bytes]:
|
||||
return
|
||||
record_header = data[offset:offset + 5]
|
||||
if not is_tls_handshake_record(record_header):
|
||||
raise ValueError(f"Expected TLS record, got {record_header} instead.")
|
||||
raise ValueError(f"Expected TLS record, got {record_header!r} instead.")
|
||||
record_size = struct.unpack("!H", record_header[3:])[0]
|
||||
if record_size == 0:
|
||||
raise ValueError("Record must not be empty.")
|
||||
@ -132,7 +132,7 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
def __repr__(self):
|
||||
return super().__repr__().replace(")", f" {self.conn.sni} {self.conn.alpn})")
|
||||
|
||||
def start_tls(self) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
def start_tls(self) -> layer.CommandGenerator[None]:
|
||||
assert not self.tls
|
||||
|
||||
tls_start = TlsStartData(self.conn, self.context)
|
||||
@ -169,6 +169,7 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
('SSL routines', 'ssl3_read_bytes', 'tlsv1 alert unknown ca'),
|
||||
('SSL routines', 'ssl3_read_bytes', 'sslv3 alert bad certificate')
|
||||
]:
|
||||
assert isinstance(last_err, list)
|
||||
err = last_err[2]
|
||||
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
|
||||
err = f"The remote server does not speak TLS."
|
||||
@ -278,6 +279,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
"""
|
||||
recv_buffer: bytearray
|
||||
server_tls_available: bool
|
||||
client_hello_parsed: bool = False
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context, context.client)
|
||||
@ -288,13 +290,17 @@ class ClientTLSLayer(_TLSLayer):
|
||||
yield from ()
|
||||
|
||||
def receive_handshake_data(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
if self.client_hello_parsed:
|
||||
return (yield from super().receive_handshake_data(data))
|
||||
self.recv_buffer.extend(data)
|
||||
try:
|
||||
client_hello = parse_client_hello(self.recv_buffer)
|
||||
except ValueError:
|
||||
return False, f"Cannot parse ClientHello: {self.recv_buffer.hex()}"
|
||||
|
||||
if not client_hello:
|
||||
if client_hello:
|
||||
self.client_hello_parsed = True
|
||||
else:
|
||||
return False, None
|
||||
|
||||
self.conn.sni = client_hello.sni
|
||||
@ -310,8 +316,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
yield from self.start_tls()
|
||||
|
||||
self.receive_handshake_data = super().receive_handshake_data
|
||||
ret = yield from self.receive_handshake_data(bytes(self.recv_buffer))
|
||||
ret = yield from super().receive_handshake_data(bytes(self.recv_buffer))
|
||||
self.recv_buffer.clear()
|
||||
return ret
|
||||
|
||||
@ -327,6 +332,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
if self.conn.sni:
|
||||
assert isinstance(self.conn.sni, bytes)
|
||||
dest = self.conn.sni.decode("idna")
|
||||
else:
|
||||
dest = human.format_address(self.context.server.address)
|
||||
|
@ -1,18 +1,16 @@
|
||||
from typing import Optional, Union, List
|
||||
from typing import Union, List, Iterator
|
||||
|
||||
import wsproto
|
||||
import wsproto.utilities
|
||||
import wsproto.frame_protocol
|
||||
import wsproto.extensions
|
||||
from wsproto.frame_protocol import CloseReason, Opcode
|
||||
from wsproto import ConnectionState
|
||||
|
||||
from mitmproxy import flow, tcp, websocket, http
|
||||
import wsproto.frame_protocol
|
||||
import wsproto.utilities
|
||||
from mitmproxy import flow, websocket, http
|
||||
from mitmproxy.proxy2 import commands, events, layer, context
|
||||
from mitmproxy.proxy2.commands import Hook
|
||||
from mitmproxy.proxy2.context import Context
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
from wsproto import ConnectionState
|
||||
from wsproto.frame_protocol import CloseReason, Opcode
|
||||
|
||||
|
||||
class WebsocketStartHook(Hook):
|
||||
@ -65,8 +63,8 @@ class WebsocketConnection(wsproto.Connection):
|
||||
self.conn = conn
|
||||
self.frame_buf = []
|
||||
|
||||
def send(self, event: wsproto.events.Event) -> commands.SendData:
|
||||
data = super().send(event)
|
||||
def send2(self, event: wsproto.events.Event) -> commands.SendData:
|
||||
data = self.send(event)
|
||||
return commands.SendData(self.conn, data)
|
||||
|
||||
def __repr__(self):
|
||||
@ -77,7 +75,7 @@ class WebsocketLayer(layer.Layer):
|
||||
"""
|
||||
WebSocket layer that intercepts and relays messages.
|
||||
"""
|
||||
flow: Optional[websocket.WebSocketFlow]
|
||||
flow: websocket.WebSocketFlow
|
||||
client_ws: WebsocketConnection
|
||||
server_ws: WebsocketConnection
|
||||
|
||||
@ -144,10 +142,10 @@ class WebsocketLayer(layer.Layer):
|
||||
if ws_event.message_finished:
|
||||
if isinstance(ws_event, wsproto.events.TextMessage):
|
||||
frame_type = Opcode.TEXT
|
||||
content = "".join(src_ws.frame_buf)
|
||||
content = "".join(src_ws.frame_buf) # type: ignore
|
||||
else:
|
||||
frame_type = Opcode.BINARY
|
||||
content = b"".join(src_ws.frame_buf)
|
||||
content = b"".join(src_ws.frame_buf) # type: ignore
|
||||
|
||||
fragmentizer = Fragmentizer(src_ws.frame_buf)
|
||||
src_ws.frame_buf.clear()
|
||||
@ -158,15 +156,15 @@ class WebsocketLayer(layer.Layer):
|
||||
|
||||
assert not message.killed # this is deprecated, instead we should have .content set to emptystr.
|
||||
|
||||
for message in fragmentizer(message.content):
|
||||
yield dst_ws.send(message)
|
||||
for msg in fragmentizer(message.content):
|
||||
yield dst_ws.send2(msg)
|
||||
|
||||
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
|
||||
yield commands.Log(
|
||||
f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
|
||||
f"(payload: {bytes(ws_event.payload)!r})"
|
||||
)
|
||||
yield dst_ws.send(ws_event)
|
||||
yield dst_ws.send2(ws_event)
|
||||
elif isinstance(ws_event, wsproto.events.CloseConnection):
|
||||
self.flow.close_sender = from_str
|
||||
self.flow.close_code = ws_event.code
|
||||
@ -175,7 +173,7 @@ class WebsocketLayer(layer.Layer):
|
||||
for ws in [self.server_ws, self.client_ws]:
|
||||
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
|
||||
# response == original event, so no need to differentiate here.
|
||||
yield ws.send(ws_event)
|
||||
yield ws.send2(ws_event)
|
||||
yield commands.CloseConnection(ws.conn)
|
||||
if ws_event.code in {1000, 1001, 1005}:
|
||||
yield WebsocketEndHook(self.flow)
|
||||
@ -219,7 +217,7 @@ class Fragmentizer:
|
||||
assert fragments
|
||||
self.fragment_lengths = [len(x) for x in fragments]
|
||||
|
||||
def __call__(self, content: Union[str, bytes]):
|
||||
def __call__(self, content: Union[str, bytes]) -> Iterator[wsproto.events.Message]:
|
||||
if not content:
|
||||
return
|
||||
if len(content) == sum(self.fragment_lengths):
|
||||
|
@ -75,6 +75,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
timeout_watchdog: TimeoutWatchdog
|
||||
client: Client
|
||||
max_conns: typing.DefaultDict[Address, asyncio.Semaphore]
|
||||
layer: layer.Layer
|
||||
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.client = context.client
|
||||
@ -93,18 +94,24 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
name="timeout watchdog",
|
||||
client=self.client.peername,
|
||||
)
|
||||
if not watch:
|
||||
return # this should not be needed, see asyncio_utils.create_task
|
||||
|
||||
self.log("client connect")
|
||||
await self.handle_hook(server_hooks.ClientConnectedHook(self.client))
|
||||
if self.client.error:
|
||||
self.log("client kill connection")
|
||||
self.transports.pop(self.client).writer.close()
|
||||
writer = self.transports.pop(self.client).writer
|
||||
assert writer
|
||||
writer.close()
|
||||
else:
|
||||
handler = asyncio_utils.create_task(
|
||||
self.handle_connection(self.client),
|
||||
name=f"client connection handler",
|
||||
client=self.client.peername,
|
||||
)
|
||||
if not handler:
|
||||
return # this should not be needed, see asyncio_utils.create_task
|
||||
self.transports[self.client].handler = handler
|
||||
self.server_event(events.Start())
|
||||
await asyncio.wait([handler])
|
||||
@ -118,8 +125,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
if self.transports:
|
||||
self.log("closing transports...", "debug")
|
||||
for io in self.transports.values():
|
||||
asyncio_utils.cancel_task(io.handler, "client disconnected")
|
||||
await asyncio.wait([x.handler for x in self.transports.values()])
|
||||
if io.handler:
|
||||
asyncio_utils.cancel_task(io.handler, "client disconnected")
|
||||
await asyncio.wait([x.handler for x in self.transports.values() if x.handler])
|
||||
self.log("transports closed!", "debug")
|
||||
|
||||
async def open_connection(self, command: commands.OpenConnection) -> None:
|
||||
@ -162,6 +170,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
self.transports[command.connection].reader = reader
|
||||
self.transports[command.connection].writer = writer
|
||||
|
||||
assert command.connection.peername
|
||||
if command.connection.address[0] != command.connection.peername[0]:
|
||||
addr = f"{command.connection.address[0]} ({human.format_address(command.connection.peername)})"
|
||||
else:
|
||||
@ -172,6 +181,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
name=f"handle_hook(server_connected) {addr}",
|
||||
client=self.client.peername,
|
||||
)
|
||||
if not connected_hook:
|
||||
return # this should not be needed, see asyncio_utils.create_task
|
||||
|
||||
self.server_event(events.OpenConnectionReply(command, None))
|
||||
|
||||
@ -183,6 +194,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
name=f"server connection handler for {addr}",
|
||||
client=self.client.peername,
|
||||
)
|
||||
if not new_handler:
|
||||
return # this should not be needed, see asyncio_utils.create_task
|
||||
self.transports[command.connection].handler = new_handler
|
||||
await asyncio.wait([new_handler])
|
||||
|
||||
@ -227,7 +240,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
await asyncio.Event().wait()
|
||||
|
||||
try:
|
||||
self.transports[connection].writer.close()
|
||||
writer = self.transports[connection].writer
|
||||
assert writer
|
||||
writer.close()
|
||||
except OSError:
|
||||
pass
|
||||
self.transports.pop(connection)
|
||||
@ -237,7 +252,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
self.log(f"Closing connection due to inactivity: {self.client}")
|
||||
asyncio_utils.cancel_task(self.transports[self.client].handler, "timeout")
|
||||
handler = self.transports[self.client].handler
|
||||
assert handler
|
||||
asyncio_utils.cancel_task(handler, "timeout")
|
||||
|
||||
async def hook_task(self, hook: commands.Hook) -> None:
|
||||
await self.handle_hook(hook)
|
||||
@ -268,11 +285,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
|
||||
return # The connection has already been closed.
|
||||
elif isinstance(command, commands.SendData):
|
||||
self.transports[command.connection].writer.write(command.data)
|
||||
writer = self.transports[command.connection].writer
|
||||
assert writer
|
||||
writer.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
self.close_connection(command.connection, command.half_close)
|
||||
elif isinstance(command, commands.GetSocket):
|
||||
socket = self.transports[command.connection].writer.get_extra_info("socket")
|
||||
writer = self.transports[command.connection].writer
|
||||
assert writer
|
||||
socket = writer.get_extra_info("socket")
|
||||
self.server_event(events.GetSocketReply(command, socket))
|
||||
elif isinstance(command, commands.Hook):
|
||||
asyncio_utils.create_task(
|
||||
@ -293,7 +314,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
return
|
||||
self.log(f"half-closing {connection}", "debug")
|
||||
try:
|
||||
self.transports[connection].writer.write_eof()
|
||||
writer = self.transports[connection].writer
|
||||
assert writer
|
||||
writer.write_eof()
|
||||
except OSError:
|
||||
# if we can't write to the socket anymore we presume it completely dead.
|
||||
connection.state = ConnectionState.CLOSED
|
||||
@ -303,7 +326,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
connection.state = ConnectionState.CLOSED
|
||||
|
||||
if connection.state is ConnectionState.CLOSED:
|
||||
asyncio_utils.cancel_task(self.transports[connection].handler, "closed by command")
|
||||
handler = self.transports[connection].handler
|
||||
assert handler
|
||||
asyncio_utils.cancel_task(handler, "closed by command")
|
||||
|
||||
|
||||
class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
|
||||
@ -359,7 +384,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
opts.mode = "reverse:http://127.0.0.1:3000/"
|
||||
|
||||
|
||||
async def handle(reader, writer):
|
||||
layer_stack = [
|
||||
# lambda ctx: layers.ServerTLSLayer(ctx),
|
||||
@ -371,8 +395,9 @@ if __name__ == "__main__":
|
||||
]
|
||||
|
||||
def next_layer(nl: layer.NextLayer):
|
||||
nl.layer = layer_stack.pop(0)(nl.context)
|
||||
nl.layer.debug = " " * len(nl.context.layers)
|
||||
l = layer_stack.pop(0)(nl.context)
|
||||
l.debug = " " * len(nl.context.layers)
|
||||
nl.layer = l
|
||||
|
||||
def request(flow: http.HTTPFlow):
|
||||
if "cached" in flow.request.path:
|
||||
@ -410,11 +435,11 @@ if __name__ == "__main__":
|
||||
"tls_start": tls_start,
|
||||
}).handle_client()
|
||||
|
||||
|
||||
coro = asyncio.start_server(handle, '127.0.0.1', 8080, loop=loop)
|
||||
server = loop.run_until_complete(coro)
|
||||
|
||||
# Serve requests until Ctrl+C is pressed
|
||||
assert server.sockets
|
||||
print(f"Serving on {human.format_address(server.sockets[0].getsockname())}")
|
||||
try:
|
||||
loop.run_forever()
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from mitmproxy.proxy2 import commands, context, events, layer
|
||||
from mitmproxy.proxy2.layer import Layer
|
||||
@ -138,7 +138,7 @@ class LayerStack:
|
||||
|
||||
def __truediv__(self, other: Layer) -> "LayerStack":
|
||||
if self._stack:
|
||||
self._stack[-1].child_layer = other
|
||||
self._stack[-1].child_layer = other # type: ignore
|
||||
self._stack.append(other)
|
||||
return self
|
||||
|
||||
|
@ -1,15 +1,15 @@
|
||||
"""
|
||||
Usage:
|
||||
- pip install pytest-benchmark
|
||||
- pytest bench.py
|
||||
Usage:
|
||||
- pip install pytest-benchmark
|
||||
- pytest bench.py
|
||||
|
||||
See also:
|
||||
- https://github.com/mitmproxy/proxybench
|
||||
- https://github.com/mitmproxy/proxybench
|
||||
"""
|
||||
import copy
|
||||
|
||||
from .layers.http import test_http, test_http2
|
||||
from .layers import test_tcp, test_tls
|
||||
from .layers.http import test_http, test_http2
|
||||
|
||||
|
||||
def test_bench_http_roundtrip(tctx, benchmark):
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from mitmproxy import log, options
|
||||
from mitmproxy import options
|
||||
from mitmproxy.addons.proxyserver import Proxyserver
|
||||
from mitmproxy.addons.termlog import TermLog
|
||||
from mitmproxy.proxy2 import context
|
||||
|
@ -2,11 +2,10 @@ import struct
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from mitmproxy.proxy2.layers.old import websocket
|
||||
|
||||
from mitmproxy.net.websockets import Frame, OPCODE
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.layers.old import websocket
|
||||
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.test import tflow
|
||||
from .. import tutils
|
||||
@ -39,24 +38,24 @@ def test_simple(tctx, ws_playbook):
|
||||
]
|
||||
|
||||
assert (
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.server, frames[0])
|
||||
>> events.DataReceived(tctx.server, frames[1])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.client, frames[1])
|
||||
>> events.DataReceived(tctx.client, frames[2])
|
||||
<< commands.SendData(tctx.server, frames[2])
|
||||
<< commands.SendData(tctx.client, frames[3])
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.server, frames[4])
|
||||
<< None
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.server, frames[0])
|
||||
>> events.DataReceived(tctx.server, frames[1])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.client, frames[1])
|
||||
>> events.DataReceived(tctx.client, frames[2])
|
||||
<< commands.SendData(tctx.server, frames[2])
|
||||
<< commands.SendData(tctx.client, frames[3])
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.server, frames[4])
|
||||
<< None
|
||||
)
|
||||
|
||||
assert len(f().messages) == 2
|
||||
@ -71,31 +70,31 @@ def test_server_close(tctx, ws_playbook):
|
||||
]
|
||||
|
||||
assert (
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.server, frames[0])
|
||||
<< commands.SendData(tctx.client, frames[0])
|
||||
<< commands.SendData(tctx.server, frames[1])
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.server, frames[0])
|
||||
<< commands.SendData(tctx.client, frames[0])
|
||||
<< commands.SendData(tctx.server, frames[1])
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
)
|
||||
|
||||
|
||||
def test_connection_closed(tctx, ws_playbook):
|
||||
f = tutils.Placeholder()
|
||||
assert (
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.ConnectionClosed(tctx.server)
|
||||
<< commands.Log("error", "Connection closed abnormally")
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
<< commands.Hook("websocket_error", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.ConnectionClosed(tctx.server)
|
||||
<< commands.Log("error", "Connection closed abnormally")
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
<< commands.Hook("websocket_error", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
)
|
||||
|
||||
assert f().error
|
||||
@ -111,16 +110,16 @@ def test_connection_failed(tctx, ws_playbook):
|
||||
]
|
||||
|
||||
assert (
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.SendData(tctx.server, frames[1])
|
||||
<< commands.SendData(tctx.client, frames[2])
|
||||
<< commands.Hook("websocket_error", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.SendData(tctx.server, frames[1])
|
||||
<< commands.SendData(tctx.client, frames[2])
|
||||
<< commands.Hook("websocket_error", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.Hook("websocket_end", f)
|
||||
>> events.HookReply(-1)
|
||||
)
|
||||
|
||||
|
||||
@ -133,15 +132,15 @@ def test_ping_pong(tctx, ws_playbook):
|
||||
]
|
||||
|
||||
assert (
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.Log("info", "WebSocket PING received from client: <no payload>")
|
||||
<< commands.SendData(tctx.server, frames[0])
|
||||
<< commands.SendData(tctx.client, frames[1])
|
||||
>> events.DataReceived(tctx.server, frames[1])
|
||||
<< commands.Log("info", "WebSocket PONG received from server: <no payload>")
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.Log("info", "WebSocket PING received from client: <no payload>")
|
||||
<< commands.SendData(tctx.server, frames[0])
|
||||
<< commands.SendData(tctx.client, frames[1])
|
||||
>> events.DataReceived(tctx.server, frames[1])
|
||||
<< commands.Log("info", "WebSocket PONG received from server: <no payload>")
|
||||
)
|
||||
|
||||
|
||||
@ -156,15 +155,15 @@ def test_ping_pong_hidden_payload(tctx, ws_playbook):
|
||||
]
|
||||
|
||||
assert (
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.server, frames[0])
|
||||
<< commands.Log("info", "WebSocket PING received from server: foobar")
|
||||
<< commands.SendData(tctx.client, frames[1])
|
||||
<< commands.SendData(tctx.server, frames[2])
|
||||
>> events.DataReceived(tctx.client, frames[3])
|
||||
<< commands.Log("info", "WebSocket PONG received from client: <no payload>")
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.server, frames[0])
|
||||
<< commands.Log("info", "WebSocket PING received from server: foobar")
|
||||
<< commands.SendData(tctx.client, frames[1])
|
||||
<< commands.SendData(tctx.server, frames[2])
|
||||
>> events.DataReceived(tctx.client, frames[3])
|
||||
<< commands.Log("info", "WebSocket PONG received from client: <no payload>")
|
||||
)
|
||||
|
||||
|
||||
@ -181,17 +180,17 @@ def test_extension(tctx, ws_playbook):
|
||||
]
|
||||
|
||||
assert (
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.server, frames[0])
|
||||
>> events.DataReceived(tctx.server, frames[1])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.client, frames[2])
|
||||
ws_playbook
|
||||
<< commands.Hook("websocket_start", f)
|
||||
>> events.HookReply(-1)
|
||||
>> events.DataReceived(tctx.client, frames[0])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.server, frames[0])
|
||||
>> events.DataReceived(tctx.server, frames[1])
|
||||
<< commands.Hook("websocket_message", f)
|
||||
>> events.HookReply(-1)
|
||||
<< commands.SendData(tctx.client, frames[2])
|
||||
)
|
||||
assert len(f().messages) == 2
|
||||
assert f().messages[0].content == "Hello"
|
||||
|
@ -8,13 +8,12 @@ helpers
|
||||
|
||||
This module contains helpers for the h2 tests.
|
||||
"""
|
||||
from hpack.hpack import Encoder
|
||||
from hyperframe.frame import (
|
||||
HeadersFrame, DataFrame, SettingsFrame, WindowUpdateFrame, PingFrame,
|
||||
GoAwayFrame, RstStreamFrame, PushPromiseFrame, PriorityFrame,
|
||||
ContinuationFrame, AltSvcFrame
|
||||
)
|
||||
from hpack.hpack import Encoder
|
||||
|
||||
|
||||
SAMPLE_SETTINGS = {
|
||||
SettingsFrame.HEADER_TABLE_SIZE: 4096,
|
||||
@ -29,6 +28,7 @@ class FrameFactory(object):
|
||||
allows test cases to easily build correct HTTP/2 frames to feed to
|
||||
hyper-h2.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.encoder = Encoder()
|
||||
|
||||
@ -176,4 +176,4 @@ class FrameFactory(object):
|
||||
Causes the encoder to send a dynamic size update in the next header
|
||||
block it sends.
|
||||
"""
|
||||
self.encoder.header_table_size = new_size
|
||||
self.encoder.header_table_size = new_size
|
||||
|
@ -347,8 +347,7 @@ def test_request_streaming(tctx, response):
|
||||
>> reply()
|
||||
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||
>> DataReceived(tctx.client, b"def")
|
||||
<< SendData(server, b"DEF")
|
||||
# Important: no request hook here!
|
||||
<< SendData(server, b"DEF") # Important: no request hook here!
|
||||
)
|
||||
elif response == "early close":
|
||||
assert (
|
||||
@ -875,7 +874,10 @@ def test_close_during_connect_hook(tctx):
|
||||
assert (
|
||||
Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
>> DataReceived(tctx.client,
|
||||
b'CONNECT hi.ls:443 HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keep-alive\r\nHost: hi.ls:443\r\n\r\n')
|
||||
b'CONNECT hi.ls:443 HTTP/1.1\r\n'
|
||||
b'Proxy-Connection: keep-alive\r\n'
|
||||
b'Connection: keep-alive\r\n'
|
||||
b'Host: hi.ls:443\r\n\r\n')
|
||||
<< http.HttpConnectHook(flow)
|
||||
>> ConnectionClosed(tctx.client)
|
||||
<< CloseConnection(tctx.client)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import random
|
||||
from typing import List, Tuple, Dict, Any
|
||||
from typing import List, Tuple
|
||||
|
||||
import h2.settings
|
||||
import hpack
|
||||
@ -17,7 +16,7 @@ from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import http
|
||||
from mitmproxy.proxy2.layers.http._http2 import split_pseudo_headers, Http2Client
|
||||
from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory
|
||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, _TracebackInPlaybook, _fmt_entry, _eq
|
||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
|
||||
|
||||
example_request_headers = (
|
||||
(b':method', b'GET'),
|
||||
@ -527,5 +526,4 @@ class TestClient:
|
||||
<< SendData(tctx.server, frame_factory.build_rst_stream_frame(1, ErrorCodes.CANCEL).serialize())
|
||||
>> DataReceived(tctx.server, frame_factory.build_data_frame(b"foo").serialize())
|
||||
<< SendData(tctx.server, frame_factory.build_rst_stream_frame(1, ErrorCodes.STREAM_CLOSED).serialize())
|
||||
# important: no ResponseData event here!
|
||||
)
|
||||
) # important: no ResponseData event here!
|
||||
|
@ -239,7 +239,8 @@ def _h2_request(chunks):
|
||||
@example([b'\x00\x00\x12\x01\x04\x00\x00\x00\x01\x84\x86\x82`\x80A\x88/\x91\xd3]\x05\\\x87\xa7\\\x81\x07'])
|
||||
@example([b'\x00\x00\x14\x01\x04\x00\x00\x00\x01A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x86`\x80\x82f\x80'])
|
||||
@example([
|
||||
b'\x00\x00%\x01\x04\x00\x00\x00\x01A\x8b/\x91\xd3]\x05\\\x87\xa6\xe3M3\x84\x86\x82`\x85\x94\xe7\x8c~\xfff\x88/\x91\xd3]\x05\\\x87\xa7\\\x82h_\x00\x00\x07\x01\x05\x00\x00\x00\x01\xc1\x84\x86\x82\xc0\xbf\xbe'])
|
||||
b'\x00\x00%\x01\x04\x00\x00\x00\x01A\x8b/\x91\xd3]\x05\\\x87\xa6\xe3M3\x84\x86\x82`\x85\x94\xe7\x8c~\xfff\x88/\x91'
|
||||
b'\xd3]\x05\\\x87\xa7\\\x82h_\x00\x00\x07\x01\x05\x00\x00\x00\x01\xc1\x84\x86\x82\xc0\xbf\xbe'])
|
||||
def test_fuzz_h2_request_chunks(chunks):
|
||||
_h2_request(chunks)
|
||||
|
||||
|
@ -2,7 +2,6 @@ import copy
|
||||
|
||||
import pytest
|
||||
|
||||
from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.context import Client, Context, Server
|
||||
|
@ -2,8 +2,8 @@ import ssl
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
from OpenSSL import SSL
|
||||
|
||||
from OpenSSL import SSL
|
||||
from mitmproxy.proxy2 import commands, context, events, layer
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
@ -98,7 +98,7 @@ class SSLTest:
|
||||
def bio_write(self, buf: bytes) -> int:
|
||||
return self.inc.write(buf)
|
||||
|
||||
def bio_read(self, bufsize: int = 2**16) -> bytes:
|
||||
def bio_read(self, bufsize: int = 2 ** 16) -> bytes:
|
||||
return self.out.read(bufsize)
|
||||
|
||||
def do_handshake(self) -> None:
|
||||
@ -138,6 +138,7 @@ def interact(playbook: tutils.Playbook, conn: context.Connection, tssl: SSLTest)
|
||||
)
|
||||
tssl.bio_write(data())
|
||||
|
||||
|
||||
def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tutils.reply:
|
||||
"""
|
||||
Helper function to simplify the syntax for tls_start hooks.
|
||||
|
@ -9,7 +9,7 @@ from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.net.http import Request, Response
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2.commands import SendData, CloseConnection, Log
|
||||
from mitmproxy.proxy2.context import Server, ConnectionState
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.proxy2.events import DataReceived, ConnectionClosed
|
||||
from mitmproxy.proxy2.layers import http, websocket
|
||||
from mitmproxy.websocket import WebSocketFlow
|
||||
|
@ -10,22 +10,22 @@ class TestNextLayer:
|
||||
playbook = tutils.Playbook(nl, hooks=True)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
>> tutils.reply()
|
||||
>> events.DataReceived(tctx.client, b"bar")
|
||||
<< layer.NextLayerHook(nl)
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
>> tutils.reply()
|
||||
>> events.DataReceived(tctx.client, b"bar")
|
||||
<< layer.NextLayerHook(nl)
|
||||
)
|
||||
assert nl.data_client() == b"foobar"
|
||||
assert nl.data_server() == b""
|
||||
|
||||
nl.layer = tutils.EchoLayer(tctx)
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply()
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
<< commands.SendData(tctx.client, b"bar")
|
||||
playbook
|
||||
>> tutils.reply()
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
<< commands.SendData(tctx.client, b"bar")
|
||||
)
|
||||
|
||||
def test_late_hook_reply(self, tctx):
|
||||
@ -37,19 +37,19 @@ class TestNextLayer:
|
||||
playbook = tutils.Playbook(nl)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
>> events.DataReceived(tctx.client, b"bar")
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
>> events.DataReceived(tctx.client, b"bar")
|
||||
)
|
||||
assert nl.data_client() == b"foo" # "bar" is paused.
|
||||
nl.layer = tutils.EchoLayer(tctx)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply(to=-2)
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
<< commands.SendData(tctx.client, b"bar")
|
||||
playbook
|
||||
>> tutils.reply(to=-2)
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
<< commands.SendData(tctx.client, b"bar")
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("layer_found", [True, False])
|
||||
@ -58,23 +58,23 @@ class TestNextLayer:
|
||||
nl = layer.NextLayer(tctx)
|
||||
playbook = tutils.Playbook(nl)
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
>> events.ConnectionClosed(tctx.client)
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
>> events.ConnectionClosed(tctx.client)
|
||||
)
|
||||
if layer_found:
|
||||
nl.layer = tutils.RecordLayer(tctx)
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply(to=-2)
|
||||
playbook
|
||||
>> tutils.reply(to=-2)
|
||||
)
|
||||
assert isinstance(nl.layer.event_log[-1], events.ConnectionClosed)
|
||||
else:
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply(to=-2)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
playbook
|
||||
>> tutils.reply(to=-2)
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
)
|
||||
|
||||
def test_func_references(self, tctx):
|
||||
@ -82,16 +82,16 @@ class TestNextLayer:
|
||||
playbook = tutils.Playbook(nl)
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(nl)
|
||||
)
|
||||
nl.layer = tutils.EchoLayer(tctx)
|
||||
handle = nl.handle_event
|
||||
assert (
|
||||
playbook
|
||||
>> tutils.reply()
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
playbook
|
||||
>> tutils.reply()
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
)
|
||||
sd, = handle(events.DataReceived(tctx.client, b"bar"))
|
||||
assert isinstance(sd, commands.SendData)
|
||||
|
@ -79,9 +79,9 @@ def _merge_sends(lst: typing.List[commands.Command], ignore_hooks: bool, ignore_
|
||||
current_send.data += x.data
|
||||
else:
|
||||
ignore = (
|
||||
(ignore_hooks and isinstance(x, commands.Hook))
|
||||
or
|
||||
(ignore_logs and isinstance(x, commands.Log))
|
||||
(ignore_hooks and isinstance(x, commands.Hook))
|
||||
or
|
||||
(ignore_logs and isinstance(x, commands.Log))
|
||||
)
|
||||
if not ignore:
|
||||
current_send = None
|
||||
@ -244,7 +244,7 @@ class Playbook:
|
||||
# the current event may still have yielded more events, so we need to insert
|
||||
# the reply *after* those additional events.
|
||||
hook_replies.append(events.HookReply(cmd))
|
||||
self.expected = self.expected[:pos+1] + hook_replies + self.expected[pos+1:]
|
||||
self.expected = self.expected[:pos + 1] + hook_replies + self.expected[pos + 1:]
|
||||
|
||||
eq(self.expected[i:], self.actual[i:]) # compare now already to set placeholders
|
||||
i += 1
|
||||
|
Loading…
Reference in New Issue
Block a user