[sans-io] lint!

This commit is contained in:
Maximilian Hils 2020-12-11 10:45:47 +01:00
parent 6b2e49eb13
commit 8201a90e22
31 changed files with 384 additions and 301 deletions

View File

@ -33,16 +33,16 @@ class MockServer(layers.http.HttpConnection):
def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
if isinstance(event, events.Start):
has_content = bool(self.flow.request.raw_content)
content = self.flow.request.raw_content
self.flow.request.timestamp_start = self.flow.request.timestamp_end = time.time()
yield layers.http.ReceiveHttp(layers.http.RequestHeaders(
1,
self.flow.request,
end_stream=not has_content,
end_stream=not content,
replay_flow=self.flow,
))
if has_content:
yield layers.http.ReceiveHttp(layers.http.RequestData(1, self.flow.request.raw_content))
if content:
yield layers.http.ReceiveHttp(layers.http.RequestData(1, content))
yield layers.http.ReceiveHttp(layers.http.RequestEndOfMessage(1))
elif isinstance(event, (
layers.http.ResponseHeaders,
@ -56,6 +56,8 @@ class MockServer(layers.http.HttpConnection):
class ReplayHandler(server.ConnectionHandler):
layer: layers.HttpLayer
def __init__(self, flow: http.HTTPFlow, options: Options) -> None:
client = flow.client_conn.copy()
client.state = ConnectionState.OPEN
@ -91,8 +93,9 @@ class ReplayHandler(server.ConnectionHandler):
if self.transports:
# close server connections
for x in self.transports.values():
x.handler.cancel()
await asyncio.wait([x.handler for x in self.transports.values()])
if x.handler:
x.handler.cancel()
await asyncio.wait([x.handler for x in self.transports.values() if x.handler])
# signal completion
self.done.set()
@ -140,6 +143,7 @@ class ClientPlayback:
return "Can't replay flow with missing content."
else:
return "Can only replay HTTP flows."
return None
def load(self, loader):
loader.add_option(

View File

@ -14,7 +14,7 @@ LayerCls = typing.Type[layer.Layer]
def stack_match(
context: context.Context,
layers: typing.List[typing.Union[LayerCls, typing.Tuple[LayerCls, ...]]]
layers: typing.Sequence[typing.Union[LayerCls, typing.Tuple[LayerCls, ...]]]
) -> bool:
if len(context.layers) != len(layers):
return False
@ -74,11 +74,15 @@ class NextLayer:
hostnames.append(context.server.address[0])
if is_tls_record_magic(data_client):
try:
sni = parse_client_hello(data_client).sni
ch = parse_client_hello(data_client)
if ch is None:
return None
sni = ch.sni
except ValueError:
return None # defer decision, wait for more input data
else:
hostnames.append(sni.decode("idna"))
if sni:
hostnames.append(sni.decode("idna"))
if not hostnames:
return False
@ -95,6 +99,8 @@ class NextLayer:
for host in hostnames
for rex in ctx.options.allow_hosts
)
else: # pragma: no cover
raise AssertionError()
def next_layer(self, nextlayer: layer.NextLayer):
if isinstance(nextlayer, base.Layer):
@ -106,10 +112,13 @@ class NextLayer:
return self.make_top_layer(context)
if len(data_client) < 3:
return
return None
client_tls = is_tls_record_magic(data_client)
s = lambda *layers: stack_match(context, layers)
def s(*layers):
return stack_match(context, layers)
top_layer = context.layers[-1]
# 1. check for --ignore/--allow
@ -117,7 +126,7 @@ class NextLayer:
if ignore is True:
return layers.TCPLayer(context, ignore=True)
if ignore is None:
return
return None
# 2. Check for TLS
if client_tls:

View File

@ -51,7 +51,7 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
def log(self, message: str, level: str = "info") -> None:
x = log.LogEntry(self.log_prefix + message, level)
x.reply = controller.DummyReply()
x.reply = controller.DummyReply() # type: ignore
asyncio_utils.create_task(
self.master.addons.handle_lifecycle("log", x),
name="ProxyConnectionHandler.log"

View File

@ -33,7 +33,7 @@ class TlsConfig:
"""
This addon supplies the proxy core with the desired OpenSSL connection objects to negotiate TLS.
"""
certstore: certs.CertStore = None
certstore: certs.CertStore
# TODO: We should support configuring TLS 1.3 cipher suites (https://github.com/mitmproxy/mitmproxy/issues/4260)
# TODO: We should re-use SSL.Context options here, if only for TLS session resumption.
@ -57,7 +57,7 @@ class TlsConfig:
our certificate should have and then fetches a matching cert from the certstore.
"""
altnames: List[bytes] = []
organization: Optional[str] = None
organization: Optional[bytes] = None
# Use upstream certificate if available.
if conn_context.server.certificate_list:
@ -130,6 +130,7 @@ class TlsConfig:
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None:
client = tls_start.context.client
server = cast(context.Server, tls_start.conn)
assert server.address
if server.sni is True:
server.sni = client.sni or server.address[0].encode()
@ -179,7 +180,7 @@ class TlsConfig:
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
ssl_ctx = net_tls.create_client_context(
cert=client_cert,
sni=server.sni.decode("idna"), # TODO: Should pass-through here.
sni=server.sni.decode("idna") if server.sni else None, # TODO: Should pass-through here.
alpn_protos=server.alpn_offers,
**args
)

View File

@ -16,26 +16,27 @@ def _parse_authority_form(hostport: bytes) -> Tuple[bytes, int]:
ValueError, if the input is malformed
"""
try:
host, port = hostport.rsplit(b":", 1)
host, port_str = hostport.rsplit(b":", 1)
if host.startswith(b"[") and host.endswith(b"]"):
host = host[1:-1]
port = int(port)
port = int(port_str)
if not check.is_valid_host(host) or not check.is_valid_port(port):
raise ValueError
except ValueError:
raise ValueError(f"Invalid host specification: {hostport}")
raise ValueError(f"Invalid host specification: {hostport!r}")
return host, port
def raise_if_http_version_unknown(http_version):
def raise_if_http_version_unknown(http_version: bytes) -> None:
if not re.match(br"^HTTP/\d\.\d$", http_version):
raise ValueError(f"Unknown HTTP version: {http_version}")
raise ValueError(f"Unknown HTTP version: {http_version!r}")
def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, bytes, bytes]:
try:
method, target, http_version = line.split()
port: Optional[int]
if target == b"*" or target.startswith(b"/"):
scheme, authority, path = b"", b"", target
@ -58,7 +59,7 @@ def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, byte
raise_if_http_version_unknown(http_version)
except ValueError as e:
raise ValueError(f"Bad HTTP request line: {line}") from e
raise ValueError(f"Bad HTTP request line: {line!r}") from e
return host, port, method, scheme, authority, path, http_version
@ -69,16 +70,16 @@ def _read_response_line(line: bytes) -> Tuple[bytes, int, bytes]:
if len(parts) == 2: # handle missing message gracefully
parts.append(b"")
http_version, status_code, reason = parts
status_code = int(status_code)
http_version, status_code_str, reason = parts
status_code = int(status_code_str)
raise_if_http_version_unknown(http_version)
except ValueError as e:
raise ValueError(f"Bad HTTP response line: {line}") from e
raise ValueError(f"Bad HTTP response line: {line!r}") from e
return http_version, status_code, reason
def _read_headers(lines: Iterable[bytes]):
def _read_headers(lines: Iterable[bytes]) -> headers.Headers:
"""
Read a set of headers.
Stop once a blank line is reached.
@ -89,7 +90,7 @@ def _read_headers(lines: Iterable[bytes]):
Raises:
exceptions.HttpSyntaxException
"""
ret = []
ret: List[Tuple[bytes, bytes]] = []
for line in lines:
if line[0] in b" \t":
if not ret:
@ -104,7 +105,7 @@ def _read_headers(lines: Iterable[bytes]):
raise ValueError()
ret.append((name, value))
except ValueError:
raise ValueError(f"Invalid header line: {line}")
raise ValueError(f"Invalid header line: {line!r}")
return headers.Headers(ret)

View File

@ -8,17 +8,20 @@ The counterpart to commands are events.
"""
import dataclasses
import re
from typing import Any, ClassVar, Dict, List, Literal, Type
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, TYPE_CHECKING
from mitmproxy.proxy2.context import Connection, Server
if TYPE_CHECKING:
import mitmproxy.proxy2.layer
class Command:
"""
Base class for all commands
"""
blocking: ClassVar[bool] = False
blocking: Union[bool, "mitmproxy.proxy2.layer.Layer"] = False
"""
Determines if the command blocks until it has been completed.

View File

@ -1,13 +1,16 @@
import uuid
import warnings
from enum import Flag
from typing import List, Literal, Optional, Sequence, Tuple, Union
from typing import List, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING
from mitmproxy import certs
from mitmproxy.coretypes import serializable
from mitmproxy.net import server_spec
from mitmproxy.options import Options
if TYPE_CHECKING:
import mitmproxy.proxy2.layer
class ConnectionState(Flag):
CLOSED = 0
@ -38,9 +41,9 @@ class Connection(serializable.Serializable):
tls: bool = False
certificate_list: Optional[Sequence[certs.Cert]] = None
"""
The TLS certificate list as sent by the peer.
The TLS certificate list as sent by the peer.
The first certificate is the end-entity certificate.
[RFC 8446] Prior to TLS 1.3, "certificate_list" ordering required each
certificate to certify the one immediately preceding it; however,
some implementations allowed some flexibility. Servers sometimes
@ -84,7 +87,7 @@ class Connection(serializable.Serializable):
return f"{type(self).__name__}({attrs})"
@property
def alpn_proto_negotiated(self) -> bytes:
def alpn_proto_negotiated(self) -> Optional[bytes]:
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", PendingDeprecationWarning)
return self.alpn
@ -185,11 +188,11 @@ class Server(Connection):
timestamp_tcp_setup: Optional[float] = None
"""TCP ACK received"""
sni = True
sni: Union[bytes, Literal[True], None] = True
"""True: client SNI, False: no SNI, bytes: custom value"""
via: Optional[server_spec.ServerSpec] = None
def __init__(self, address: Optional[tuple]):
def __init__(self, address: Optional[Address]):
self.id = str(uuid.uuid4())
self.address = address
@ -250,7 +253,7 @@ class Server(Connection):
self.via = state["via2"]
@property
def ip_address(self) -> Address:
def ip_address(self) -> Optional[Address]:
warnings.warn("Server.ip_address is deprecated, use Server.peername instead.", PendingDeprecationWarning)
return self.peername

View File

@ -3,22 +3,22 @@ Base class for protocol layers.
"""
import collections
import textwrap
import typing
from abc import abstractmethod
from typing import Optional, List, ClassVar, Deque, NamedTuple, Generator, Any, TypeVar
from mitmproxy import log
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.commands import Command, Hook
from mitmproxy.proxy2.context import Connection, Context
T = typing.TypeVar('T')
CommandGenerator = typing.Generator[Command, typing.Optional[events.CommandReply], T]
T = TypeVar('T')
CommandGenerator = Generator[Command, Any, T]
"""
A function annotated with CommandGenerator[bool] may yield commands and ultimately return a boolean value.
"""
class Paused(typing.NamedTuple):
class Paused(NamedTuple):
"""
State of a layer that's paused because it is waiting for a command reply.
"""
@ -27,11 +27,11 @@ class Paused(typing.NamedTuple):
class Layer:
__last_debug_message: typing.ClassVar[str] = ""
__last_debug_message: ClassVar[str] = ""
context: Context
_paused: typing.Optional[Paused]
_paused_event_queue: typing.Deque[events.Event]
debug: typing.Optional[str] = None
_paused: Optional[Paused]
_paused_event_queue: Deque[events.Event]
debug: Optional[str] = None
"""
Enable debug logging by assigning a prefix string for log messages.
Different amounts of whitespace for different layers work well.
@ -94,6 +94,7 @@ class Layer:
if self.debug is not None:
yield self.__debug(f"{'>>' if pause_finished else '>!'} {event}")
if pause_finished:
assert isinstance(event, events.CommandReply)
yield from self.__continue(event)
else:
self._paused_event_queue.append(event)
@ -114,7 +115,7 @@ class Layer:
except StopIteration:
return
while command:
while True:
if self.debug is not None:
if not isinstance(command, commands.Log):
yield self.__debug(f"<< {command}")
@ -128,19 +129,23 @@ class Layer:
return
else:
yield command
command = next(command_generator, None)
try:
command = next(command_generator)
except StopIteration:
return
def __continue(self, event: events.CommandReply):
"""continue processing events after being paused"""
assert self._paused is not None
command_generator = self._paused.generator
self._paused = None
yield from self.__process(command_generator, event.reply)
while not self._paused and self._paused_event_queue:
event = self._paused_event_queue.popleft()
ev = self._paused_event_queue.popleft()
if self.debug is not None:
yield self.__debug(f"!> {event}")
command_generator = self._handle_event(event)
yield self.__debug(f"!> {ev}")
command_generator = self._handle_event(ev)
yield from self.__process(command_generator)
@ -152,10 +157,10 @@ class NextLayerHook(Hook):
class NextLayer(Layer):
layer: typing.Optional[Layer]
layer: Optional[Layer]
"""The next layer. To be set by an addon."""
events: typing.List[mevents.Event]
events: List[mevents.Event]
"""All events that happened before a decision was made."""
_ask_on_start: bool

View File

@ -1,7 +1,7 @@
import collections
import time
import typing
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Dict, DefaultDict, List
from mitmproxy import flow, http
from mitmproxy.net import server_spec
@ -13,7 +13,7 @@ from mitmproxy.proxy2.layers import tls, websocket, tcp
from mitmproxy.proxy2.layers.http import _upstream_proxy
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._base import HttpCommand, ReceiveHttp, StreamId, HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
@ -22,7 +22,7 @@ from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client, Http2Server
def validate_request(mode, request) -> typing.Optional[str]:
def validate_request(mode, request) -> Optional[str]:
if request.scheme not in ("http", "https", ""):
return f"Invalid request scheme: {request.scheme}"
if mode is HTTPMode.transparent and request.method == "CONNECT":
@ -39,9 +39,9 @@ class GetHttpConnection(HttpCommand):
Open an HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead.
"""
blocking = True
address: typing.Tuple[str, int]
address: Tuple[str, int]
tls: bool
via: typing.Optional[server_spec.ServerSpec]
via: Optional[server_spec.ServerSpec]
def __hash__(self):
return id(self)
@ -61,29 +61,23 @@ class GetHttpConnection(HttpCommand):
@dataclass
class GetHttpConnectionReply(events.CommandReply):
command: GetHttpConnection
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
reply: Union[Tuple[None, str], Tuple[Connection, None]]
"""connection object, error message"""
@dataclass
class RegisterHttpConnection(HttpCommand):
"""
Register that a HTTP connection has been successfully established.
Register that a HTTP connection attempt has been completed.
"""
connection: Connection
err: str
def __init__(self, connection: Connection, err: str):
self.connection = connection
self.err = err
err: Optional[str]
@dataclass
class SendHttp(HttpCommand):
connection: Connection
event: HttpEvent
def __init__(self, event: HttpEvent, connection: Connection):
self.connection = connection
self.event = event
connection: Connection
def __repr__(self) -> str:
return f"Send({self.event})"
@ -93,8 +87,8 @@ class HttpStream(layer.Layer):
request_body_buf: bytes
response_body_buf: bytes
flow: http.HTTPFlow
stream_id: StreamId = None
child_layer: typing.Optional[layer.Layer] = None
stream_id: StreamId
child_layer: Optional[layer.Layer] = None
@property
def mode(self):
@ -102,12 +96,13 @@ class HttpStream(layer.Layer):
parent: HttpLayer = self.context.layers[i - 1]
return parent.mode
def __init__(self, context: Context):
def __init__(self, context: Context, stream_id: int):
super().__init__(context)
self.request_body_buf = b""
self.response_body_buf = b""
self.client_state = self.state_uninitialized
self.server_state = self.state_uninitialized
self.stream_id = stream_id
def __repr__(self):
return (
@ -131,7 +126,6 @@ class HttpStream(layer.Layer):
@expect(RequestHeaders)
def state_wait_for_request_headers(self, event: RequestHeaders) -> layer.CommandGenerator[None]:
self.stream_id = event.stream_id
if not event.replay_flow:
self.flow = http.HTTPFlow(
self.context.client,
@ -152,6 +146,7 @@ class HttpStream(layer.Layer):
if self.mode is HTTPMode.transparent:
# Determine .scheme, .host and .port attributes for transparent requests
assert self.context.server.address
self.flow.request.data.host = self.context.server.address[0]
self.flow.request.data.port = self.context.server.address[1]
self.flow.request.scheme = "https" if self.context.server.tls else "http"
@ -176,10 +171,11 @@ class HttpStream(layer.Layer):
if self.mode is HTTPMode.regular and not self.flow.request.is_http2:
# Set the request target to origin-form for HTTP/1, some servers don't support absolute-form requests.
# see https://github.com/mitmproxy/mitmproxy/issues/1759
self.flow.request.authority = b""
self.flow.request.authority = ""
# update host header in reverse proxy mode
if self.context.options.mode.startswith("reverse:") and not self.context.options.keep_host_header:
assert self.context.server.address
self.flow.request.host_header = url.hostport(
"https" if self.context.server.tls else "http",
self.context.server.address[0],
@ -210,7 +206,7 @@ class HttpStream(layer.Layer):
self.server_state = self.state_wait_for_response_headers
@expect(RequestData, RequestEndOfMessage)
def state_stream_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
def state_stream_request_body(self, event: Union[RequestData, RequestEndOfMessage]) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData):
if callable(self.flow.request.stream):
event.data = self.flow.request.stream(event.data)
@ -255,10 +251,10 @@ class HttpStream(layer.Layer):
if not ok:
return
has_content = bool(self.flow.request.raw_content)
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not has_content), self.context.server)
if has_content:
yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server)
content = self.flow.request.raw_content
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not content), self.context.server)
if content:
yield SendHttp(RequestData(self.stream_id, content), self.context.server)
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
@expect(ResponseHeaders)
@ -275,6 +271,7 @@ class HttpStream(layer.Layer):
@expect(ResponseData, ResponseEndOfMessage)
def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.flow.response
if isinstance(event, ResponseData):
if callable(self.flow.response.stream):
data = self.flow.response.stream(event.data)
@ -289,12 +286,14 @@ class HttpStream(layer.Layer):
if isinstance(event, ResponseData):
self.response_body_buf += event.data
elif isinstance(event, ResponseEndOfMessage):
assert self.flow.response
self.flow.response.data.content = self.response_body_buf
self.response_body_buf = b""
yield from self.send_response()
def send_response(self, already_streamed: bool = False):
"""We have either consumed the entire response from the server or the response was set by an addon."""
assert self.flow.response
self.flow.response.timestamp_end = time.time()
yield HttpResponseHook(self.flow)
self.server_state = self.state_done
@ -302,10 +301,10 @@ class HttpStream(layer.Layer):
return
if not already_streamed:
has_content = bool(self.flow.response.raw_content)
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
if has_content:
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
content = self.flow.response.raw_content
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not content), self.context.client)
if content:
yield SendHttp(ResponseData(self.stream_id, content), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
@ -331,7 +330,7 @@ class HttpStream(layer.Layer):
)
# The client may have closed the connection while we were waiting for the hook to complete.
# We peek into the event queue to see if that is the case.
killed_by_remote = False
killed_by_remote = None
for evt in self._paused_event_queue:
if isinstance(evt, RequestProtocolError):
killed_by_remote = evt.message
@ -356,7 +355,7 @@ class HttpStream(layer.Layer):
def handle_protocol_error(
self,
event: typing.Union[RequestProtocolError, ResponseProtocolError]
event: Union[RequestProtocolError, ResponseProtocolError]
) -> layer.CommandGenerator[None]:
is_client_error_but_we_already_talk_upstream = (
isinstance(event, RequestProtocolError)
@ -450,6 +449,8 @@ class HttpStream(layer.Layer):
@expect(RequestData, RequestEndOfMessage, events.Event)
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.flow.response
assert self.child_layer
# HTTP events -> normal connection events
if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data)
@ -509,10 +510,10 @@ class HttpLayer(layer.Layer):
ConnectionEvent -> HttpEvent -> HttpCommand -> ConnectionCommand
"""
mode: HTTPMode
command_sources: typing.Dict[commands.Command, layer.Layer]
streams: typing.Dict[int, HttpStream]
connections: typing.Dict[Connection, layer.Layer]
waiting_for_establishment: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
command_sources: Dict[commands.Command, layer.Layer]
streams: Dict[int, HttpStream]
connections: Dict[Connection, layer.Layer]
waiting_for_establishment: DefaultDict[Connection, List[GetHttpConnection]]
def __init__(self, context: Context, mode: HTTPMode):
super().__init__(context)
@ -522,6 +523,7 @@ class HttpLayer(layer.Layer):
self.streams = {}
self.command_sources = {}
http_conn: HttpConnection
if self.context.client.alpn == b"h2":
http_conn = Http2Server(context.fork())
else:
@ -554,7 +556,7 @@ class HttpLayer(layer.Layer):
def event_to_child(
self,
child: typing.Union[layer.Layer, HttpStream],
child: Union[layer.Layer, HttpStream],
event: events.Event,
) -> layer.CommandGenerator[None]:
for command in child.handle_event(event):
@ -567,7 +569,7 @@ class HttpLayer(layer.Layer):
if isinstance(command, ReceiveHttp):
if isinstance(command.event, RequestHeaders):
self.streams[command.event.stream_id] = yield from self.make_stream()
yield from self.make_stream(command.event.stream_id)
stream = self.streams[command.event.stream_id]
yield from self.event_to_child(stream, command.event)
elif isinstance(command, SendHttp):
@ -585,11 +587,10 @@ class HttpLayer(layer.Layer):
else:
raise AssertionError(f"Not a command: {event}")
def make_stream(self) -> layer.CommandGenerator[HttpStream]:
def make_stream(self, stream_id: int) -> layer.CommandGenerator[None]:
ctx = self.context.fork()
stream = HttpStream(ctx)
yield from self.event_to_child(stream, events.Start())
return stream
self.streams[stream_id] = HttpStream(ctx, stream_id)
yield from self.event_to_child(self.streams[stream_id], events.Start())
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True) -> layer.CommandGenerator[None]:
# Do we already have a connection we can re-use?
@ -649,6 +650,7 @@ class HttpLayer(layer.Layer):
def register_connection(self, command: RegisterHttpConnection) -> layer.CommandGenerator[None]:
waiting = self.waiting_for_establishment.pop(command.connection)
reply: Union[Tuple[None, str], Tuple[Connection, None]]
if command.err:
reply = (None, command.err)
else:
@ -675,11 +677,13 @@ class HttpLayer(layer.Layer):
class HttpClient(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
err: Optional[str]
if self.context.server.connected:
err = None
else:
err = yield commands.OpenConnection(self.context.server)
if not err:
child_layer: layer.Layer
if self.context.server.alpn == b"h2":
child_layer = Http2Client(self.context)
else:

View File

@ -27,7 +27,8 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
response: Optional[http.HTTPResponse] = None
request_done: bool = False
response_done: bool = False
state: Callable[[events.ConnectionEvent], layer.CommandGenerator[None]]
# this is a bit of a hack to make both mypy and PyCharm happy.
state: Union[Callable[[events.ConnectionEvent], layer.CommandGenerator[None]], Callable]
body_reader: TBodyReader
buf: ReceiveBuffer
@ -50,10 +51,12 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, HttpEvent):
yield from self.send(event)
else:
elif isinstance(event, events.ConnectionEvent):
if isinstance(event, events.DataReceived) and self.state != self.passthrough:
self.buf += event.data
yield from self.state(event)
else:
raise AssertionError(f"Unexpected event: {event}")
@expect(events.Start)
def start(self, _) -> layer.CommandGenerator[None]:
@ -63,6 +66,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
state = start
def read_body(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.stream_id
while True:
try:
if isinstance(event, events.DataReceived):
@ -83,6 +87,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
if data:
yield ReceiveHttp(self.ReceiveData(self.stream_id, data))
elif isinstance(h11_event, h11.EndOfMessage):
assert self.request
if h11_event.headers:
raise NotImplementedError(f"HTTP trailers are not implemented yet.")
if self.request.data.method.upper() != b"CONNECT":
@ -99,6 +104,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
We wait for the current flow to be finished before parsing the next message,
as we may want to upgrade to WebSocket or plain TCP before that.
"""
assert self.stream_id
if isinstance(event, events.DataReceived):
return
elif isinstance(event, events.ConnectionClosed):
@ -123,6 +129,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
self.buf.compress()
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.stream_id
if isinstance(event, events.DataReceived):
yield ReceiveHttp(self.ReceiveData(self.stream_id, event.data))
elif isinstance(event, events.ConnectionClosed):
@ -137,19 +144,19 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
if response:
self.response_done = True
if self.request_done and self.response_done:
assert self.request
assert self.response
if should_make_pipe(self.request, self.response):
yield from self.make_pipe()
return
connection_done = (
http1_sansio.expected_http_body_size(self.request, self.response) == -1 or
http1.connection_close(self.request.http_version, self.request.headers) or
http1.connection_close(self.response.http_version, self.response.headers) or
(
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
# This simplifies our connection management quite a bit as we can rely on
# the proxyserver's max-connection-per-server throttling.
self.request.is_http2 and isinstance(self, Http1Client)
)
http1_sansio.expected_http_body_size(self.request, self.response) == -1
or http1.connection_close(self.request.http_version, self.request.headers)
or http1.connection_close(self.response.http_version, self.response.headers)
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
# This simplifies our connection management quite a bit as we can rely on
# the proxyserver's max-connection-per-server throttling.
or (self.request.is_http2 and isinstance(self, Http1Client))
)
if connection_done:
yield commands.CloseConnection(self.conn)
@ -172,6 +179,7 @@ class Http1Server(Http1Connection):
ReceiveProtocolError = RequestProtocolError
ReceiveData = RequestData
ReceiveEndOfMessage = RequestEndOfMessage
stream_id: int
def __init__(self, context: Context):
super().__init__(context, context.client)
@ -185,7 +193,7 @@ class Http1Server(Http1Connection):
if response.is_http2:
response = response.copy()
# Convert to an HTTP/1 response.
response.http_version = b"HTTP/1.1"
response.http_version = "HTTP/1.1"
# not everyone supports empty reason phrases, so we better make up one.
response.reason = status_codes.RESPONSES.get(response.status_code, "")
# Shall we set a Content-Length header here if there is none?
@ -194,6 +202,7 @@ class Http1Server(Http1Connection):
raw = http1.assemble_response_head(response)
yield commands.SendData(self.conn, raw)
elif isinstance(event, ResponseData):
assert self.response
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
else:
@ -201,6 +210,7 @@ class Http1Server(Http1Connection):
if raw:
yield commands.SendData(self.conn, raw)
elif isinstance(event, ResponseEndOfMessage):
assert self.response
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
yield from self.mark_done(response=True)
@ -235,7 +245,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, events.ConnectionClosed):
buf = bytes(self.buf)
if buf.strip():
yield commands.Log(f"Client closed connection before completing request headers: {buf}")
yield commands.Log(f"Client closed connection before completing request headers: {buf!r}")
yield commands.CloseConnection(self.conn)
else:
raise AssertionError(f"Unexpected event: {event}")
@ -268,13 +278,14 @@ class Http1Client(Http1Connection):
if request.is_http2:
# Convert to an HTTP/1 request.
request = request.copy() # (we could probably be a bit more efficient here.)
request.http_version = b"HTTP/1.1"
request.http_version = "HTTP/1.1"
if "Host" not in request.headers and request.authority:
request.headers.insert(0, "Host", request.authority)
request.authority = b""
request.authority = ""
raw = http1.assemble_request_head(request)
yield commands.SendData(self.conn, raw)
elif isinstance(event, RequestData):
assert self.request
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
else:
@ -282,6 +293,7 @@ class Http1Client(Http1Connection):
if raw:
yield commands.SendData(self.conn, raw)
elif isinstance(event, RequestEndOfMessage):
assert self.request
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1_sansio.expected_http_body_size(self.request, self.response) == -1:
@ -300,6 +312,7 @@ class Http1Client(Http1Connection):
yield commands.Log(f"Unexpected data from server: {bytes(self.buf)!r}")
yield commands.CloseConnection(self.conn)
return
assert self.stream_id
response_head = self.buf.maybe_extract_lines()
if response_head:

View File

@ -1,7 +1,7 @@
import collections
import time
from enum import Enum
from typing import ClassVar, DefaultDict, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import ClassVar, DefaultDict, Dict, List, Optional, Tuple, Type, Union, Sequence
import h2.config
import h2.connection
@ -76,6 +76,7 @@ class Http2Connection(HttpConnection):
elif isinstance(event, HttpEvent):
if isinstance(event, self.SendData):
assert isinstance(event, (RequestData, ResponseData))
self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, self.SendEndOfMessage):
stream = self.h2_conn.streams.get(event.stream_id)
@ -85,6 +86,7 @@ class Http2Connection(HttpConnection):
if self.is_closed(event.stream_id):
self.streams.pop(event.stream_id, None)
elif isinstance(event, self.SendProtocolError):
assert isinstance(event, (RequestProtocolError, ResponseProtocolError))
stream = self.h2_conn.streams.get(event.stream_id)
if stream.state_machine.state is not h2.stream.StreamState.CLOSED:
code = {
@ -192,6 +194,7 @@ class Http2Connection(HttpConnection):
yield Log(f"Ignoring unknown HTTP/2 frame type: {event.frame.type}")
else:
raise AssertionError(f"Unexpected event: {event!r}")
return False
def protocol_error(
self,
@ -208,7 +211,7 @@ class Http2Connection(HttpConnection):
for stream_id in self.streams:
yield ReceiveHttp(self.ReceiveProtocolError(stream_id, msg))
self.streams.clear()
self._handle_event = self.done
self._handle_event = self.done # type: ignore
@expect(DataReceived, HttpEvent, ConnectionClosed)
def done(self, _) -> CommandGenerator[None]:
@ -285,6 +288,7 @@ class Http2Server(Http2Connection):
)
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
yield ReceiveHttp(RequestHeaders(event.stream_id, request, end_stream=bool(event.stream_ended)))
return False
else:
return (yield from super().handle_h2_event(event))
@ -303,9 +307,9 @@ class Http2Client(Http2Connection):
ReceiveData = ResponseData
ReceiveEndOfMessage = ResponseEndOfMessage
our_stream_id = Dict[int, int]
their_stream_id = Dict[int, int]
stream_queue = DefaultDict[int, List[Event]]
our_stream_id: Dict[int, int]
their_stream_id: Dict[int, int]
stream_queue: DefaultDict[int, List[Event]]
"""Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
provisional_max_concurrency: Optional[int] = 10
"""A provisional currency limit before we get the server's first settings frame."""
@ -360,9 +364,9 @@ class Http2Client(Http2Connection):
def _handle_event2(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, RequestHeaders):
pseudo_headers = [
(b':method', event.request.method),
(b':scheme', event.request.scheme),
(b':path', event.request.path),
(b':method', event.request.data.method),
(b':scheme', event.request.data.scheme),
(b':path', event.request.data.path),
]
if event.request.authority:
pseudo_headers.append((b":authority", event.request.data.authority))
@ -410,6 +414,7 @@ class Http2Client(Http2Connection):
)
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended)))
return False
elif isinstance(event, h2.events.RequestReceived):
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")
return True
@ -422,13 +427,13 @@ class Http2Client(Http2Connection):
return (yield from super().handle_h2_event(event))
def split_pseudo_headers(h2_headers: Iterable[Tuple[bytes, bytes]]) -> Tuple[Dict[bytes, bytes], net_http.Headers]:
def split_pseudo_headers(h2_headers: Sequence[Tuple[bytes, bytes]]) -> Tuple[Dict[bytes, bytes], net_http.Headers]:
pseudo_headers: Dict[bytes, bytes] = {}
i = 0
for (header, value) in h2_headers:
if header.startswith(b":"):
if header in pseudo_headers:
raise ValueError(f"Duplicate HTTP/2 pseudo header: {header}")
raise ValueError(f"Duplicate HTTP/2 pseudo header: {header!r}")
pseudo_headers[header] = value
i += 1
else:
@ -441,7 +446,7 @@ def split_pseudo_headers(h2_headers: Iterable[Tuple[bytes, bytes]]) -> Tuple[Dic
def parse_h2_request_headers(
h2_headers: Iterable[Tuple[bytes, bytes]]
h2_headers: Sequence[Tuple[bytes, bytes]]
) -> Tuple[str, int, bytes, bytes, bytes, bytes, net_http.Headers]:
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
pseudo_headers, headers = split_pseudo_headers(h2_headers)
@ -468,7 +473,7 @@ def parse_h2_request_headers(
return host, port, method, scheme, authority, path, headers
def parse_h2_response_headers(h2_headers: Iterable[Tuple[bytes, bytes]]) -> Tuple[int, net_http.Headers]:
def parse_h2_response_headers(h2_headers: Sequence[Tuple[bytes, bytes]]) -> Tuple[int, net_http.Headers]:
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
pseudo_headers, headers = split_pseudo_headers(h2_headers)

View File

@ -1,11 +1,11 @@
import collections
from typing import DefaultDict, Deque, NamedTuple
import h2.config
import h2.connection
import h2.events
import h2.settings
import h2.exceptions
from typing import DefaultDict, Deque, NamedTuple, Optional
import h2.settings
class H2ConnectionLogger(h2.config.DummyLogger):

View File

@ -28,6 +28,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
conn=ctx.server
)
assert self.tunnel_connection.address
self.conn.via = server_spec.ServerSpec(
"https" if self.tunnel_connection.tls else "http",
self.tunnel_connection.address
@ -43,6 +44,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
self.conn.alpn = self.tunnel_connection.alpn
if not self.send_connect:
return (yield from super().start_handshake())
assert self.conn.address
req = http.make_connect_request(self.conn.address)
raw = http1.assemble_request(req)
yield commands.SendData(self.tunnel_connection, raw)
@ -66,7 +68,8 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
return True, None
else:
raw_resp = b"\n".join(response_head)
yield commands.Log(f"{human.format_address(self.tunnel_connection.address)}: {raw_resp}", level="debug")
yield commands.Log(f"{human.format_address(self.tunnel_connection.address)}: {raw_resp!r}",
level="debug")
return False, f"{response.status_code} {response.reason}"
else:
return False, None

View File

@ -11,6 +11,7 @@ class ReverseProxy(layer.Layer):
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
self.context.server = Server(spec.address)
child_layer: layer.Layer
if spec.scheme not in ("http", "tcp"):
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0].encode()
@ -32,6 +33,7 @@ class HttpProxy(layer.Layer):
class TransparentProxy(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
assert platform.original_addr is not None
socket = yield commands.GetSocket(self.context.client)
try:
self.context.server.address = platform.original_addr(socket)

View File

@ -3,7 +3,7 @@ from typing import Optional
from mitmproxy import flow, tcp
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.commands import Hook
from mitmproxy.proxy2.context import ConnectionState, Context
from mitmproxy.proxy2.context import ConnectionState, Context, Connection
from mitmproxy.proxy2.utils import expect
@ -73,6 +73,7 @@ class TCPLayer(layer.Layer):
@expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
from_client = event.connection == self.context.client
send_to: Connection
if from_client:
send_to = self.context.server
else:

View File

@ -42,7 +42,7 @@ def handshake_record_contents(data: bytes) -> Iterator[bytes]:
return
record_header = data[offset:offset + 5]
if not is_tls_handshake_record(record_header):
raise ValueError(f"Expected TLS record, got {record_header} instead.")
raise ValueError(f"Expected TLS record, got {record_header!r} instead.")
record_size = struct.unpack("!H", record_header[3:])[0]
if record_size == 0:
raise ValueError("Record must not be empty.")
@ -132,7 +132,7 @@ class _TLSLayer(tunnel.TunnelLayer):
def __repr__(self):
return super().__repr__().replace(")", f" {self.conn.sni} {self.conn.alpn})")
def start_tls(self) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
def start_tls(self) -> layer.CommandGenerator[None]:
assert not self.tls
tls_start = TlsStartData(self.conn, self.context)
@ -169,6 +169,7 @@ class _TLSLayer(tunnel.TunnelLayer):
('SSL routines', 'ssl3_read_bytes', 'tlsv1 alert unknown ca'),
('SSL routines', 'ssl3_read_bytes', 'sslv3 alert bad certificate')
]:
assert isinstance(last_err, list)
err = last_err[2]
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
err = f"The remote server does not speak TLS."
@ -278,6 +279,7 @@ class ClientTLSLayer(_TLSLayer):
"""
recv_buffer: bytearray
server_tls_available: bool
client_hello_parsed: bool = False
def __init__(self, context: context.Context):
super().__init__(context, context.client)
@ -288,13 +290,17 @@ class ClientTLSLayer(_TLSLayer):
yield from ()
def receive_handshake_data(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
if self.client_hello_parsed:
return (yield from super().receive_handshake_data(data))
self.recv_buffer.extend(data)
try:
client_hello = parse_client_hello(self.recv_buffer)
except ValueError:
return False, f"Cannot parse ClientHello: {self.recv_buffer.hex()}"
if not client_hello:
if client_hello:
self.client_hello_parsed = True
else:
return False, None
self.conn.sni = client_hello.sni
@ -310,8 +316,7 @@ class ClientTLSLayer(_TLSLayer):
yield from self.start_tls()
self.receive_handshake_data = super().receive_handshake_data
ret = yield from self.receive_handshake_data(bytes(self.recv_buffer))
ret = yield from super().receive_handshake_data(bytes(self.recv_buffer))
self.recv_buffer.clear()
return ret
@ -327,6 +332,7 @@ class ClientTLSLayer(_TLSLayer):
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
if self.conn.sni:
assert isinstance(self.conn.sni, bytes)
dest = self.conn.sni.decode("idna")
else:
dest = human.format_address(self.context.server.address)

View File

@ -1,18 +1,16 @@
from typing import Optional, Union, List
from typing import Union, List, Iterator
import wsproto
import wsproto.utilities
import wsproto.frame_protocol
import wsproto.extensions
from wsproto.frame_protocol import CloseReason, Opcode
from wsproto import ConnectionState
from mitmproxy import flow, tcp, websocket, http
import wsproto.frame_protocol
import wsproto.utilities
from mitmproxy import flow, websocket, http
from mitmproxy.proxy2 import commands, events, layer, context
from mitmproxy.proxy2.commands import Hook
from mitmproxy.proxy2.context import Context
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
from wsproto import ConnectionState
from wsproto.frame_protocol import CloseReason, Opcode
class WebsocketStartHook(Hook):
@ -65,8 +63,8 @@ class WebsocketConnection(wsproto.Connection):
self.conn = conn
self.frame_buf = []
def send(self, event: wsproto.events.Event) -> commands.SendData:
data = super().send(event)
def send2(self, event: wsproto.events.Event) -> commands.SendData:
data = self.send(event)
return commands.SendData(self.conn, data)
def __repr__(self):
@ -77,7 +75,7 @@ class WebsocketLayer(layer.Layer):
"""
WebSocket layer that intercepts and relays messages.
"""
flow: Optional[websocket.WebSocketFlow]
flow: websocket.WebSocketFlow
client_ws: WebsocketConnection
server_ws: WebsocketConnection
@ -144,10 +142,10 @@ class WebsocketLayer(layer.Layer):
if ws_event.message_finished:
if isinstance(ws_event, wsproto.events.TextMessage):
frame_type = Opcode.TEXT
content = "".join(src_ws.frame_buf)
content = "".join(src_ws.frame_buf) # type: ignore
else:
frame_type = Opcode.BINARY
content = b"".join(src_ws.frame_buf)
content = b"".join(src_ws.frame_buf) # type: ignore
fragmentizer = Fragmentizer(src_ws.frame_buf)
src_ws.frame_buf.clear()
@ -158,15 +156,15 @@ class WebsocketLayer(layer.Layer):
assert not message.killed # this is deprecated, instead we should have .content set to emptystr.
for message in fragmentizer(message.content):
yield dst_ws.send(message)
for msg in fragmentizer(message.content):
yield dst_ws.send2(msg)
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
yield commands.Log(
f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
f"(payload: {bytes(ws_event.payload)!r})"
)
yield dst_ws.send(ws_event)
yield dst_ws.send2(ws_event)
elif isinstance(ws_event, wsproto.events.CloseConnection):
self.flow.close_sender = from_str
self.flow.close_code = ws_event.code
@ -175,7 +173,7 @@ class WebsocketLayer(layer.Layer):
for ws in [self.server_ws, self.client_ws]:
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
# response == original event, so no need to differentiate here.
yield ws.send(ws_event)
yield ws.send2(ws_event)
yield commands.CloseConnection(ws.conn)
if ws_event.code in {1000, 1001, 1005}:
yield WebsocketEndHook(self.flow)
@ -219,7 +217,7 @@ class Fragmentizer:
assert fragments
self.fragment_lengths = [len(x) for x in fragments]
def __call__(self, content: Union[str, bytes]):
def __call__(self, content: Union[str, bytes]) -> Iterator[wsproto.events.Message]:
if not content:
return
if len(content) == sum(self.fragment_lengths):

View File

@ -75,6 +75,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
timeout_watchdog: TimeoutWatchdog
client: Client
max_conns: typing.DefaultDict[Address, asyncio.Semaphore]
layer: layer.Layer
def __init__(self, context: Context) -> None:
self.client = context.client
@ -93,18 +94,24 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
name="timeout watchdog",
client=self.client.peername,
)
if not watch:
return # this should not be needed, see asyncio_utils.create_task
self.log("client connect")
await self.handle_hook(server_hooks.ClientConnectedHook(self.client))
if self.client.error:
self.log("client kill connection")
self.transports.pop(self.client).writer.close()
writer = self.transports.pop(self.client).writer
assert writer
writer.close()
else:
handler = asyncio_utils.create_task(
self.handle_connection(self.client),
name=f"client connection handler",
client=self.client.peername,
)
if not handler:
return # this should not be needed, see asyncio_utils.create_task
self.transports[self.client].handler = handler
self.server_event(events.Start())
await asyncio.wait([handler])
@ -118,8 +125,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
if self.transports:
self.log("closing transports...", "debug")
for io in self.transports.values():
asyncio_utils.cancel_task(io.handler, "client disconnected")
await asyncio.wait([x.handler for x in self.transports.values()])
if io.handler:
asyncio_utils.cancel_task(io.handler, "client disconnected")
await asyncio.wait([x.handler for x in self.transports.values() if x.handler])
self.log("transports closed!", "debug")
async def open_connection(self, command: commands.OpenConnection) -> None:
@ -162,6 +170,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.transports[command.connection].reader = reader
self.transports[command.connection].writer = writer
assert command.connection.peername
if command.connection.address[0] != command.connection.peername[0]:
addr = f"{command.connection.address[0]} ({human.format_address(command.connection.peername)})"
else:
@ -172,6 +181,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
name=f"handle_hook(server_connected) {addr}",
client=self.client.peername,
)
if not connected_hook:
return # this should not be needed, see asyncio_utils.create_task
self.server_event(events.OpenConnectionReply(command, None))
@ -183,6 +194,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
name=f"server connection handler for {addr}",
client=self.client.peername,
)
if not new_handler:
return # this should not be needed, see asyncio_utils.create_task
self.transports[command.connection].handler = new_handler
await asyncio.wait([new_handler])
@ -227,7 +240,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
await asyncio.Event().wait()
try:
self.transports[connection].writer.close()
writer = self.transports[connection].writer
assert writer
writer.close()
except OSError:
pass
self.transports.pop(connection)
@ -237,7 +252,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
asyncio_utils.cancel_task(self.transports[self.client].handler, "timeout")
handler = self.transports[self.client].handler
assert handler
asyncio_utils.cancel_task(handler, "timeout")
async def hook_task(self, hook: commands.Hook) -> None:
await self.handle_hook(hook)
@ -268,11 +285,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
return # The connection has already been closed.
elif isinstance(command, commands.SendData):
self.transports[command.connection].writer.write(command.data)
writer = self.transports[command.connection].writer
assert writer
writer.write(command.data)
elif isinstance(command, commands.CloseConnection):
self.close_connection(command.connection, command.half_close)
elif isinstance(command, commands.GetSocket):
socket = self.transports[command.connection].writer.get_extra_info("socket")
writer = self.transports[command.connection].writer
assert writer
socket = writer.get_extra_info("socket")
self.server_event(events.GetSocketReply(command, socket))
elif isinstance(command, commands.Hook):
asyncio_utils.create_task(
@ -293,7 +314,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
return
self.log(f"half-closing {connection}", "debug")
try:
self.transports[connection].writer.write_eof()
writer = self.transports[connection].writer
assert writer
writer.write_eof()
except OSError:
# if we can't write to the socket anymore we presume it completely dead.
connection.state = ConnectionState.CLOSED
@ -303,7 +326,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
connection.state = ConnectionState.CLOSED
if connection.state is ConnectionState.CLOSED:
asyncio_utils.cancel_task(self.transports[connection].handler, "closed by command")
handler = self.transports[connection].handler
assert handler
asyncio_utils.cancel_task(handler, "closed by command")
class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
@ -359,7 +384,6 @@ if __name__ == "__main__":
)
opts.mode = "reverse:http://127.0.0.1:3000/"
async def handle(reader, writer):
layer_stack = [
# lambda ctx: layers.ServerTLSLayer(ctx),
@ -371,8 +395,9 @@ if __name__ == "__main__":
]
def next_layer(nl: layer.NextLayer):
nl.layer = layer_stack.pop(0)(nl.context)
nl.layer.debug = " " * len(nl.context.layers)
l = layer_stack.pop(0)(nl.context)
l.debug = " " * len(nl.context.layers)
nl.layer = l
def request(flow: http.HTTPFlow):
if "cached" in flow.request.path:
@ -410,11 +435,11 @@ if __name__ == "__main__":
"tls_start": tls_start,
}).handle_client()
coro = asyncio.start_server(handle, '127.0.0.1', 8080, loop=loop)
server = loop.run_until_complete(coro)
# Serve requests until Ctrl+C is pressed
assert server.sockets
print(f"Serving on {human.format_address(server.sockets[0].getsockname())}")
try:
loop.run_forever()

View File

@ -1,5 +1,5 @@
from enum import Enum, auto
from typing import Callable, List, Optional, Tuple, Type
from typing import List, Optional, Tuple
from mitmproxy.proxy2 import commands, context, events, layer
from mitmproxy.proxy2.layer import Layer
@ -138,7 +138,7 @@ class LayerStack:
def __truediv__(self, other: Layer) -> "LayerStack":
if self._stack:
self._stack[-1].child_layer = other
self._stack[-1].child_layer = other # type: ignore
self._stack.append(other)
return self

View File

@ -1,15 +1,15 @@
"""
Usage:
- pip install pytest-benchmark
- pytest bench.py
Usage:
- pip install pytest-benchmark
- pytest bench.py
See also:
- https://github.com/mitmproxy/proxybench
- https://github.com/mitmproxy/proxybench
"""
import copy
from .layers.http import test_http, test_http2
from .layers import test_tcp, test_tls
from .layers.http import test_http, test_http2
def test_bench_http_roundtrip(tctx, benchmark):

View File

@ -1,6 +1,6 @@
import pytest
from mitmproxy import log, options
from mitmproxy import options
from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.addons.termlog import TermLog
from mitmproxy.proxy2 import context

View File

@ -2,11 +2,10 @@ import struct
from unittest import mock
import pytest
from mitmproxy.proxy2.layers.old import websocket
from mitmproxy.net.websockets import Frame, OPCODE
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.layers.old import websocket
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.test import tflow
from .. import tutils
@ -39,24 +38,24 @@ def test_simple(tctx, ws_playbook):
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.client, frames[2])
<< commands.SendData(tctx.server, frames[2])
<< commands.SendData(tctx.client, frames[3])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[4])
<< None
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.client, frames[2])
<< commands.SendData(tctx.server, frames[2])
<< commands.SendData(tctx.client, frames[3])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[4])
<< None
)
assert len(f().messages) == 2
@ -71,31 +70,31 @@ def test_server_close(tctx, ws_playbook):
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
<< commands.CloseConnection(tctx.client)
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
<< commands.CloseConnection(tctx.client)
)
def test_connection_closed(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.ConnectionClosed(tctx.server)
<< commands.Log("error", "Connection closed abnormally")
<< commands.CloseConnection(tctx.client)
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.ConnectionClosed(tctx.server)
<< commands.Log("error", "Connection closed abnormally")
<< commands.CloseConnection(tctx.client)
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
)
assert f().error
@ -111,16 +110,16 @@ def test_connection_failed(tctx, ws_playbook):
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.SendData(tctx.client, frames[2])
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.SendData(tctx.client, frames[2])
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
)
@ -133,15 +132,15 @@ def test_ping_pong(tctx, ws_playbook):
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Log("info", "WebSocket PING received from client: <no payload>")
<< commands.SendData(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Log("info", "WebSocket PONG received from server: <no payload>")
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Log("info", "WebSocket PING received from client: <no payload>")
<< commands.SendData(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Log("info", "WebSocket PONG received from server: <no payload>")
)
@ -156,15 +155,15 @@ def test_ping_pong_hidden_payload(tctx, ws_playbook):
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[0])
<< commands.Log("info", "WebSocket PING received from server: foobar")
<< commands.SendData(tctx.client, frames[1])
<< commands.SendData(tctx.server, frames[2])
>> events.DataReceived(tctx.client, frames[3])
<< commands.Log("info", "WebSocket PONG received from client: <no payload>")
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[0])
<< commands.Log("info", "WebSocket PING received from server: foobar")
<< commands.SendData(tctx.client, frames[1])
<< commands.SendData(tctx.server, frames[2])
>> events.DataReceived(tctx.client, frames[3])
<< commands.Log("info", "WebSocket PONG received from client: <no payload>")
)
@ -181,17 +180,17 @@ def test_extension(tctx, ws_playbook):
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, frames[2])
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, frames[2])
)
assert len(f().messages) == 2
assert f().messages[0].content == "Hello"

View File

@ -8,13 +8,12 @@ helpers
This module contains helpers for the h2 tests.
"""
from hpack.hpack import Encoder
from hyperframe.frame import (
HeadersFrame, DataFrame, SettingsFrame, WindowUpdateFrame, PingFrame,
GoAwayFrame, RstStreamFrame, PushPromiseFrame, PriorityFrame,
ContinuationFrame, AltSvcFrame
)
from hpack.hpack import Encoder
SAMPLE_SETTINGS = {
SettingsFrame.HEADER_TABLE_SIZE: 4096,
@ -29,6 +28,7 @@ class FrameFactory(object):
allows test cases to easily build correct HTTP/2 frames to feed to
hyper-h2.
"""
def __init__(self):
self.encoder = Encoder()
@ -176,4 +176,4 @@ class FrameFactory(object):
Causes the encoder to send a dynamic size update in the next header
block it sends.
"""
self.encoder.header_table_size = new_size
self.encoder.header_table_size = new_size

View File

@ -347,8 +347,7 @@ def test_request_streaming(tctx, response):
>> reply()
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
>> DataReceived(tctx.client, b"def")
<< SendData(server, b"DEF")
# Important: no request hook here!
<< SendData(server, b"DEF") # Important: no request hook here!
)
elif response == "early close":
assert (
@ -875,7 +874,10 @@ def test_close_during_connect_hook(tctx):
assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular))
>> DataReceived(tctx.client,
b'CONNECT hi.ls:443 HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keep-alive\r\nHost: hi.ls:443\r\n\r\n')
b'CONNECT hi.ls:443 HTTP/1.1\r\n'
b'Proxy-Connection: keep-alive\r\n'
b'Connection: keep-alive\r\n'
b'Host: hi.ls:443\r\n\r\n')
<< http.HttpConnectHook(flow)
>> ConnectionClosed(tctx.client)
<< CloseConnection(tctx.client)

View File

@ -1,5 +1,4 @@
import random
from typing import List, Tuple, Dict, Any
from typing import List, Tuple
import h2.settings
import hpack
@ -17,7 +16,7 @@ from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import http
from mitmproxy.proxy2.layers.http._http2 import split_pseudo_headers, Http2Client
from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, _TracebackInPlaybook, _fmt_entry, _eq
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
example_request_headers = (
(b':method', b'GET'),
@ -527,5 +526,4 @@ class TestClient:
<< SendData(tctx.server, frame_factory.build_rst_stream_frame(1, ErrorCodes.CANCEL).serialize())
>> DataReceived(tctx.server, frame_factory.build_data_frame(b"foo").serialize())
<< SendData(tctx.server, frame_factory.build_rst_stream_frame(1, ErrorCodes.STREAM_CLOSED).serialize())
# important: no ResponseData event here!
)
) # important: no ResponseData event here!

View File

@ -239,7 +239,8 @@ def _h2_request(chunks):
@example([b'\x00\x00\x12\x01\x04\x00\x00\x00\x01\x84\x86\x82`\x80A\x88/\x91\xd3]\x05\\\x87\xa7\\\x81\x07'])
@example([b'\x00\x00\x14\x01\x04\x00\x00\x00\x01A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x86`\x80\x82f\x80'])
@example([
b'\x00\x00%\x01\x04\x00\x00\x00\x01A\x8b/\x91\xd3]\x05\\\x87\xa6\xe3M3\x84\x86\x82`\x85\x94\xe7\x8c~\xfff\x88/\x91\xd3]\x05\\\x87\xa7\\\x82h_\x00\x00\x07\x01\x05\x00\x00\x00\x01\xc1\x84\x86\x82\xc0\xbf\xbe'])
b'\x00\x00%\x01\x04\x00\x00\x00\x01A\x8b/\x91\xd3]\x05\\\x87\xa6\xe3M3\x84\x86\x82`\x85\x94\xe7\x8c~\xfff\x88/\x91'
b'\xd3]\x05\\\x87\xa7\\\x82h_\x00\x00\x07\x01\x05\x00\x00\x00\x01\xc1\x84\x86\x82\xc0\xbf\xbe'])
def test_fuzz_h2_request_chunks(chunks):
_h2_request(chunks)

View File

@ -2,7 +2,6 @@ import copy
import pytest
from mitmproxy.http import HTTPFlow
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.context import Client, Context, Server

View File

@ -2,8 +2,8 @@ import ssl
import typing
import pytest
from OpenSSL import SSL
from OpenSSL import SSL
from mitmproxy.proxy2 import commands, context, events, layer
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.layers import tls
@ -98,7 +98,7 @@ class SSLTest:
def bio_write(self, buf: bytes) -> int:
return self.inc.write(buf)
def bio_read(self, bufsize: int = 2**16) -> bytes:
def bio_read(self, bufsize: int = 2 ** 16) -> bytes:
return self.out.read(bufsize)
def do_handshake(self) -> None:
@ -138,6 +138,7 @@ def interact(playbook: tutils.Playbook, conn: context.Connection, tssl: SSLTest)
)
tssl.bio_write(data())
def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tutils.reply:
"""
Helper function to simplify the syntax for tls_start hooks.

View File

@ -9,7 +9,7 @@ from mitmproxy.http import HTTPFlow
from mitmproxy.net.http import Request, Response
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2.commands import SendData, CloseConnection, Log
from mitmproxy.proxy2.context import Server, ConnectionState
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.events import DataReceived, ConnectionClosed
from mitmproxy.proxy2.layers import http, websocket
from mitmproxy.websocket import WebSocketFlow

View File

@ -10,22 +10,22 @@ class TestNextLayer:
playbook = tutils.Playbook(nl, hooks=True)
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
>> tutils.reply()
>> events.DataReceived(tctx.client, b"bar")
<< layer.NextLayerHook(nl)
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
>> tutils.reply()
>> events.DataReceived(tctx.client, b"bar")
<< layer.NextLayerHook(nl)
)
assert nl.data_client() == b"foobar"
assert nl.data_server() == b""
nl.layer = tutils.EchoLayer(tctx)
assert (
playbook
>> tutils.reply()
<< commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar")
playbook
>> tutils.reply()
<< commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar")
)
def test_late_hook_reply(self, tctx):
@ -37,19 +37,19 @@ class TestNextLayer:
playbook = tutils.Playbook(nl)
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
>> events.DataReceived(tctx.client, b"bar")
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
>> events.DataReceived(tctx.client, b"bar")
)
assert nl.data_client() == b"foo" # "bar" is paused.
nl.layer = tutils.EchoLayer(tctx)
assert (
playbook
>> tutils.reply(to=-2)
<< commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar")
playbook
>> tutils.reply(to=-2)
<< commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar")
)
@pytest.mark.parametrize("layer_found", [True, False])
@ -58,23 +58,23 @@ class TestNextLayer:
nl = layer.NextLayer(tctx)
playbook = tutils.Playbook(nl)
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
>> events.ConnectionClosed(tctx.client)
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
>> events.ConnectionClosed(tctx.client)
)
if layer_found:
nl.layer = tutils.RecordLayer(tctx)
assert (
playbook
>> tutils.reply(to=-2)
playbook
>> tutils.reply(to=-2)
)
assert isinstance(nl.layer.event_log[-1], events.ConnectionClosed)
else:
assert (
playbook
>> tutils.reply(to=-2)
<< commands.CloseConnection(tctx.client)
playbook
>> tutils.reply(to=-2)
<< commands.CloseConnection(tctx.client)
)
def test_func_references(self, tctx):
@ -82,16 +82,16 @@ class TestNextLayer:
playbook = tutils.Playbook(nl)
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
)
nl.layer = tutils.EchoLayer(tctx)
handle = nl.handle_event
assert (
playbook
>> tutils.reply()
<< commands.SendData(tctx.client, b"foo")
playbook
>> tutils.reply()
<< commands.SendData(tctx.client, b"foo")
)
sd, = handle(events.DataReceived(tctx.client, b"bar"))
assert isinstance(sd, commands.SendData)

View File

@ -79,9 +79,9 @@ def _merge_sends(lst: typing.List[commands.Command], ignore_hooks: bool, ignore_
current_send.data += x.data
else:
ignore = (
(ignore_hooks and isinstance(x, commands.Hook))
or
(ignore_logs and isinstance(x, commands.Log))
(ignore_hooks and isinstance(x, commands.Hook))
or
(ignore_logs and isinstance(x, commands.Log))
)
if not ignore:
current_send = None
@ -244,7 +244,7 @@ class Playbook:
# the current event may still have yielded more events, so we need to insert
# the reply *after* those additional events.
hook_replies.append(events.HookReply(cmd))
self.expected = self.expected[:pos+1] + hook_replies + self.expected[pos+1:]
self.expected = self.expected[:pos + 1] + hook_replies + self.expected[pos + 1:]
eq(self.expected[i:], self.actual[i:]) # compare now already to set placeholders
i += 1