mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] fixes, fixes, fixes
This commit is contained in:
parent
b2060356b6
commit
a30a6758f3
@ -86,12 +86,11 @@ class NextLayer:
|
||||
|
||||
def next_layer(self, nextlayer: layer.NextLayer):
|
||||
nextlayer.layer = self._next_layer(nextlayer.context, nextlayer.data_client())
|
||||
# nextlayer.layer.debug = " " * len(nextlayer.context.layers)
|
||||
|
||||
def _next_layer(self, context: context.Context, data_client: bytes) -> typing.Optional[layer.Layer]:
|
||||
if len(context.layers) == 0:
|
||||
return self.make_top_layer(context)
|
||||
if len(context.layers) == 1:
|
||||
return layers.ServerTLSLayer(context)
|
||||
|
||||
if len(data_client) < 3:
|
||||
return
|
||||
@ -113,21 +112,14 @@ class NextLayer:
|
||||
if isinstance(top_layer, layers.ServerTLSLayer):
|
||||
return layers.ClientTLSLayer(context)
|
||||
else:
|
||||
if s(modes.HttpProxy):
|
||||
# A "Secure Web Proxy" (https://www.chromium.org/developers/design-documents/secure-web-proxy)
|
||||
# This does not imply TLS on the server side.
|
||||
pass
|
||||
else:
|
||||
# In all other cases, client TLS implies TLS for both ends.
|
||||
context.server.tls = True
|
||||
return layers.ServerTLSLayer(context)
|
||||
|
||||
# 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy.
|
||||
if any([
|
||||
s(modes.HttpProxy, layers.ServerTLSLayer),
|
||||
s(modes.HttpProxy, layers.ServerTLSLayer, layers.ClientTLSLayer),
|
||||
s(modes.HttpProxy),
|
||||
s(modes.HttpProxy, layers.ClientTLSLayer),
|
||||
]):
|
||||
return layers.HTTPLayer(context, HTTPMode.regular)
|
||||
return layers.HttpLayer(context, HTTPMode.regular)
|
||||
if ctx.options.mode.startswith("upstream:") and len(context.layers) <= 3 and isinstance(top_layer,
|
||||
layers.ServerTLSLayer):
|
||||
raise NotImplementedError()
|
||||
@ -153,7 +145,7 @@ class NextLayer:
|
||||
return layers.TCPLayer(context)
|
||||
|
||||
# 6. Assume HTTP by default.
|
||||
return layers.HTTPLayer(context, HTTPMode.transparent)
|
||||
return layers.HttpLayer(context, HTTPMode.transparent)
|
||||
|
||||
def make_top_layer(self, context: context.Context) -> layer.Layer:
|
||||
if ctx.options.mode == "regular":
|
||||
|
@ -13,9 +13,11 @@ from mitmproxy.proxy2.layers import tls
|
||||
|
||||
def alpn_select_callback(conn: SSL.Connection, options):
|
||||
server_alpn = conn.get_app_data()["server_alpn"]
|
||||
http2 = conn.get_app_data()["http2"]
|
||||
if server_alpn and server_alpn in options:
|
||||
return server_alpn
|
||||
for alpn in tls.HTTP_ALPNS:
|
||||
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
|
||||
for alpn in http_alpns:
|
||||
if alpn in options:
|
||||
return alpn
|
||||
else:
|
||||
@ -76,11 +78,11 @@ class TlsConfig:
|
||||
)
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_app_data({
|
||||
"server_alpn": tls_start.context.server.alpn
|
||||
"server_alpn": tls_start.context.server.alpn,
|
||||
"http2": ctx.options.http2,
|
||||
})
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
|
||||
|
||||
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None:
|
||||
client = tls_start.context.client
|
||||
server: context.Server = tls_start.conn
|
||||
@ -120,7 +122,7 @@ class TlsConfig:
|
||||
if os.path.exists(path):
|
||||
client_cert = path
|
||||
|
||||
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
|
||||
args["cipher_list"] = b':'.join(server.cipher_list) if server.cipher_list else None
|
||||
ssl_ctx = net_tls.create_client_context(
|
||||
cert=client_cert,
|
||||
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Flag, auto
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from typing import List, Literal, Optional, Sequence, Union
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.options import Options
|
||||
@ -16,7 +16,7 @@ class Connection:
|
||||
"""
|
||||
Connections exposed to the layers only contain metadata, no socket objects.
|
||||
"""
|
||||
address: tuple
|
||||
address: Optional[tuple]
|
||||
state: ConnectionState
|
||||
tls: bool = False
|
||||
tls_established: bool = False
|
||||
@ -25,7 +25,7 @@ class Connection:
|
||||
alpn_offers: Sequence[bytes] = ()
|
||||
cipher_list: Sequence[bytes] = ()
|
||||
tls_version: Optional[str] = None
|
||||
sni: Union[bytes, bool, None]
|
||||
sni: Union[bytes, Literal[True], None]
|
||||
|
||||
timestamp_tls_setup: Optional[float] = None
|
||||
|
||||
@ -33,21 +33,17 @@ class Connection:
|
||||
def connected(self):
|
||||
return self.state is ConnectionState.OPEN
|
||||
|
||||
@connected.setter
|
||||
def connected(self, val: bool) -> None:
|
||||
# We should really set .state, but verdict is still due if we even want to keep .state around.
|
||||
# We allow setting .connected while we figure that out.
|
||||
if val:
|
||||
self.state = ConnectionState.OPEN
|
||||
else:
|
||||
self.state = ConnectionState.CLOSED
|
||||
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}({repr(self.__dict__)})"
|
||||
attrs = repr({
|
||||
k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k,lambda: v)()
|
||||
for k, v in self.__dict__.items()
|
||||
})
|
||||
return f"{type(self).__name__}({attrs})"
|
||||
|
||||
|
||||
class Client(Connection):
|
||||
sni: Optional[bytes] = None
|
||||
sni: Union[bytes, None] = None
|
||||
address: tuple
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = address
|
||||
@ -55,9 +51,9 @@ class Client(Connection):
|
||||
|
||||
|
||||
class Server(Connection):
|
||||
sni: Union[bytes, bool] = True
|
||||
sni = True
|
||||
"""True: client SNI, False: no SNI, bytes: custom value"""
|
||||
address: Optional[tuple]
|
||||
via: Optional["Server"] = None
|
||||
|
||||
def __init__(self, address: Optional[tuple]):
|
||||
self.address = address
|
||||
|
@ -5,7 +5,7 @@ The counterpart to events are commands.
|
||||
"""
|
||||
import socket
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
|
||||
from mitmproxy.proxy2 import commands
|
||||
from mitmproxy.proxy2.context import Connection
|
||||
@ -55,7 +55,6 @@ class DataReceived(ConnectionEvent):
|
||||
return f"DataReceived({target}, {self.data})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandReply(Event):
|
||||
"""
|
||||
Emitted when a command has been finished, e.g.
|
||||
@ -65,6 +64,7 @@ class CommandReply(Event):
|
||||
reply: typing.Any
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
assert is_dataclass(cls)
|
||||
if cls is CommandReply:
|
||||
raise TypeError("CommandReply may not be instantiated directly.")
|
||||
return super().__new__(cls)
|
||||
|
@ -5,6 +5,7 @@ import collections
|
||||
import textwrap
|
||||
import typing
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import log
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
@ -149,11 +150,14 @@ class NextLayer(Layer):
|
||||
events: typing.List[mevents.Event]
|
||||
"""All events that happened before a decision was made."""
|
||||
|
||||
def __init__(self, context: Context) -> None:
|
||||
_ask_on_start: bool
|
||||
|
||||
def __init__(self, context: Context, ask_on_start: bool = False) -> None:
|
||||
super().__init__(context)
|
||||
self.context.layers.remove(self)
|
||||
self.events = []
|
||||
self.layer = None
|
||||
self.events = []
|
||||
self._ask_on_start = ask_on_start
|
||||
self._handle = None
|
||||
|
||||
def __repr__(self):
|
||||
@ -169,11 +173,13 @@ class NextLayer(Layer):
|
||||
self.events.append(event)
|
||||
|
||||
# We receive new data. Let's find out if we can determine the next layer now?
|
||||
if isinstance(event, mevents.DataReceived):
|
||||
if self._ask_on_start and isinstance(event, events.Start):
|
||||
yield from self._ask()
|
||||
elif isinstance(event, mevents.DataReceived):
|
||||
# For now, we only ask if we have received new data to reduce hook noise.
|
||||
yield from self.ask_now()
|
||||
yield from self._ask()
|
||||
|
||||
def ask_now(self):
|
||||
def _ask(self):
|
||||
"""
|
||||
Manually trigger a next_layer hook.
|
||||
The only use at the moment is to make sure that the top layer is initialized.
|
||||
|
@ -1,11 +1,11 @@
|
||||
from . import modes
|
||||
from .http import HTTPLayer
|
||||
from .http import HttpLayer
|
||||
from .tcp import TCPLayer
|
||||
from .tls import ClientTLSLayer, ServerTLSLayer
|
||||
|
||||
__all__ = [
|
||||
"modes",
|
||||
"HTTPLayer",
|
||||
"HttpLayer",
|
||||
"TCPLayer",
|
||||
"ClientTLSLayer", "ServerTLSLayer",
|
||||
]
|
||||
|
@ -1,5 +1,6 @@
|
||||
import collections
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import flow, http
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
@ -8,7 +9,7 @@ from mitmproxy.proxy2.context import Connection, Context, Server
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp
|
||||
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
||||
@ -37,6 +38,7 @@ class GetHttpConnection(HttpCommand):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetHttpConnectionReply(events.CommandReply):
|
||||
command: GetHttpConnection
|
||||
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
|
||||
@ -67,8 +69,6 @@ class SendHttp(HttpCommand):
|
||||
return f"Send({self.event})"
|
||||
|
||||
|
||||
|
||||
|
||||
class HttpStream(layer.Layer):
|
||||
request_body_buf: bytes
|
||||
response_body_buf: bytes
|
||||
@ -78,7 +78,7 @@ class HttpStream(layer.Layer):
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
parent: HTTPLayer = self.context.layers[-2]
|
||||
parent: HttpLayer = self.context.layers[-2]
|
||||
return parent.mode
|
||||
|
||||
def __init__(self, context: Context):
|
||||
@ -309,7 +309,7 @@ class HttpStream(layer.Layer):
|
||||
yield from ()
|
||||
|
||||
|
||||
class HTTPLayer(layer.Layer):
|
||||
class HttpLayer(layer.Layer):
|
||||
"""
|
||||
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
|
||||
HttpEvent: We have received request headers
|
||||
@ -339,22 +339,22 @@ class HTTPLayer(layer.Layer):
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
|
||||
return f"HttpLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
|
||||
|
||||
def _handle_event(self, event: events.Event):
|
||||
if isinstance(event, events.Start):
|
||||
return
|
||||
elif isinstance(event, events.CommandReply):
|
||||
try:
|
||||
stream = self.stream_by_command.pop(event.command)
|
||||
except KeyError:
|
||||
raise
|
||||
self.event_to_child(stream, event)
|
||||
elif isinstance(event, events.ConnectionEvent):
|
||||
if event.connection == self.context.server and self.context.server not in self.connections:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
handler = self.connections[event.connection]
|
||||
except KeyError:
|
||||
raise
|
||||
self.event_to_child(handler, event)
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
@ -388,7 +388,7 @@ class HTTPLayer(layer.Layer):
|
||||
|
||||
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
|
||||
# Do we already have a connection we can re-use?
|
||||
for connection, layer in self.connections.items():
|
||||
for connection in self.connections:
|
||||
connection_suitable = (
|
||||
reuse and
|
||||
event.connection_spec_matches(connection) and
|
||||
@ -402,7 +402,7 @@ class HTTPLayer(layer.Layer):
|
||||
self.waiting_for_establishment[connection].append(event)
|
||||
else:
|
||||
stream = self.stream_by_command.pop(event)
|
||||
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None)))
|
||||
self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None)))
|
||||
return
|
||||
|
||||
can_reuse_context_connection = (
|
||||
@ -414,8 +414,11 @@ class HTTPLayer(layer.Layer):
|
||||
layer = HttpClient(context)
|
||||
if not can_reuse_context_connection:
|
||||
context.server = Server(event.address)
|
||||
if context.options.http2:
|
||||
context.server.alpn_offers = tls.HTTP_ALPNS
|
||||
else:
|
||||
context.server.alpn_offers = tls.HTTP1_ALPNS
|
||||
if event.tls:
|
||||
context.server.tls = True
|
||||
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
|
||||
orig = layer
|
||||
layer = tls.ServerTLSLayer(context)
|
||||
@ -432,7 +435,6 @@ class HTTPLayer(layer.Layer):
|
||||
|
||||
if command.err:
|
||||
reply = (None, command.err)
|
||||
self.connections.pop(command.connection)
|
||||
else:
|
||||
reply = (command.connection, None)
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import http
|
||||
from mitmproxy.proxy2 import commands
|
||||
|
||||
|
@ -12,6 +12,7 @@ from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
||||
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
from ._base import HttpConnection
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
@ -131,7 +132,7 @@ class Http1Server(Http1Connection):
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||
elif http1.expected_http_body_size(self.request, self.response) == -1 or self.request.first_line_format == "authority":
|
||||
yield commands.CloseConnection(self.conn)
|
||||
yield from self.mark_done(response=True)
|
||||
elif isinstance(event, ResponseProtocolError):
|
||||
@ -162,7 +163,13 @@ class Http1Server(Http1Connection):
|
||||
request_head = self.buf.maybe_extract_lines()
|
||||
if request_head:
|
||||
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
||||
try:
|
||||
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
||||
except ValueError as e:
|
||||
yield commands.Log(f"{human.format_address(self.conn.address)}: {e}")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
self.state = self.wait
|
||||
return
|
||||
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
|
||||
|
||||
if self.request.first_line_format == "authority":
|
||||
@ -181,6 +188,7 @@ class Http1Server(Http1Connection):
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
if bytes(self.buf).strip():
|
||||
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
|
||||
yield commands.Log(f"Receive Buffer: {bytes(self.buf)}", level="debug")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
@ -2,6 +2,7 @@ from mitmproxy import platform
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Server
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
|
||||
|
||||
@ -11,9 +12,10 @@ class ReverseProxy(layer.Layer):
|
||||
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
|
||||
self.context.server = Server(spec.address)
|
||||
if spec.scheme not in ("http", "tcp"):
|
||||
self.context.server.tls = True
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0]
|
||||
child_layer = tls.ServerTLSLayer(self.context)
|
||||
else:
|
||||
child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
yield from child_layer.handle_event(event)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from mitmproxy import flow, tcp
|
||||
|
@ -90,7 +90,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
||||
return None
|
||||
|
||||
|
||||
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
|
||||
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
|
||||
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
|
||||
|
||||
|
||||
# We need these classes as hooks can only have one argument at the moment.
|
||||
@ -122,10 +123,16 @@ class _TLSLayer(layer.Layer):
|
||||
tls: SSL.Connection = None
|
||||
"""The OpenSSL connection object"""
|
||||
child_layer: layer.Layer
|
||||
errored: bool = False
|
||||
"""Have we errored yet?"""
|
||||
|
||||
def __init__(self, context: context.Context, conn: context.Connection):
|
||||
super().__init__(context)
|
||||
|
||||
assert not conn.tls
|
||||
conn.tls = True
|
||||
self.conn = conn
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
def __repr__(self):
|
||||
if not self.tls:
|
||||
@ -138,8 +145,6 @@ class _TLSLayer(layer.Layer):
|
||||
|
||||
def start_tls(self, initial_data: bytes = b""):
|
||||
assert not self.tls
|
||||
assert self.conn.connected
|
||||
self.conn.tls = True
|
||||
|
||||
tls_start = TlsStartData(self.conn, self.context)
|
||||
yield TlsStartHook(tls_start)
|
||||
@ -227,7 +232,7 @@ class _TLSLayer(layer.Layer):
|
||||
)
|
||||
if close:
|
||||
self.conn.state &= ~context.ConnectionState.CAN_READ
|
||||
yield commands.Log(f"TLS close_notify {self.conn}")
|
||||
yield commands.Log(f"TLS close_notify {self.conn}", level="debug")
|
||||
yield from self.event_to_child(
|
||||
events.ConnectionClosed(self.conn)
|
||||
)
|
||||
@ -252,12 +257,13 @@ class _TLSLayer(layer.Layer):
|
||||
pass # We have already dispatched a ConnectionClosed to the child layer.
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
else:
|
||||
elif not self.errored:
|
||||
yield from self.on_handshake_error("connection closed without notice")
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
self.errored = True
|
||||
yield commands.CloseConnection(self.conn)
|
||||
|
||||
|
||||
@ -269,13 +275,12 @@ class ServerTLSLayer(_TLSLayer):
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context, context.server)
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
@expect(events.Start)
|
||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||
self.context.server.tls = True
|
||||
def state_start(self, event) -> layer.CommandGenerator[None]:
|
||||
if self.context.server.connected:
|
||||
yield from self.start_tls()
|
||||
yield from self.event_to_child(event)
|
||||
self._handle_event = super()._handle_event
|
||||
|
||||
_handle_event = state_start
|
||||
@ -327,20 +332,12 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
"""
|
||||
recv_buffer: bytearray
|
||||
server_tls_available: bool
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
assert isinstance(context.layers[-1], ServerTLSLayer)
|
||||
super().__init__(context, self.context.client)
|
||||
super().__init__(context, context.client)
|
||||
self.server_tls_available = isinstance(self.context.layers[-2], ServerTLSLayer)
|
||||
self.recv_buffer = bytearray()
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
@expect(events.Start)
|
||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||
self.context.client.tls = True
|
||||
self._handle_event = self.state_wait_for_clienthello
|
||||
yield from ()
|
||||
|
||||
_handle_event = state_start
|
||||
|
||||
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
||||
@ -361,8 +358,8 @@ class ClientTLSLayer(_TLSLayer):
|
||||
if tls_clienthello.establish_server_tls_first and not self.context.server.tls_established:
|
||||
err = yield from self.start_server_tls()
|
||||
if err:
|
||||
yield commands.Log("Unable to establish TLS connection with server. "
|
||||
"Trying to establish TLS with client anyway.")
|
||||
yield commands.Log(f"Unable to establish TLS connection with server ({err}). "
|
||||
f"Trying to establish TLS with client anyway.")
|
||||
|
||||
yield from self.start_tls(bytes(self.recv_buffer))
|
||||
self.recv_buffer.clear()
|
||||
@ -373,14 +370,17 @@ class ClientTLSLayer(_TLSLayer):
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
_handle_event = state_wait_for_clienthello
|
||||
|
||||
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
|
||||
"""
|
||||
We often need information from the upstream connection to establish TLS with the client.
|
||||
For example, we need to check if the client does ALPN or not.
|
||||
"""
|
||||
if not self.server_tls_available:
|
||||
return "No server TLS available."
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
yield commands.Log(f"Cannot establish server connection: {err}")
|
||||
return err
|
||||
else:
|
||||
return None
|
||||
@ -400,3 +400,14 @@ class ClientTLSLayer(_TLSLayer):
|
||||
level="warn"
|
||||
)
|
||||
yield from super().on_handshake_error(err)
|
||||
|
||||
|
||||
class MockTLSLayer(_TLSLayer):
|
||||
"""Mock layer to disable actual TLS and use cleartext in tests.
|
||||
|
||||
Use like so:
|
||||
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: context.Context):
|
||||
super().__init__(ctx, context.Server(None))
|
||||
|
@ -15,6 +15,7 @@ import time
|
||||
import traceback
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
@ -27,11 +28,6 @@ from mitmproxy.utils import human
|
||||
from mitmproxy.utils.data import pkg_data
|
||||
|
||||
|
||||
class StreamIO(typing.NamedTuple):
|
||||
r: asyncio.StreamReader
|
||||
w: asyncio.StreamWriter
|
||||
|
||||
|
||||
class TimeoutWatchdog:
|
||||
last_activity: float
|
||||
CONNECTION_TIMEOUT = 10 * 60
|
||||
@ -72,8 +68,15 @@ class TimeoutWatchdog:
|
||||
self.can_timeout.set()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionIO:
|
||||
handler: typing.Optional[asyncio.Task] = None
|
||||
reader: typing.Optional[asyncio.StreamReader] = None
|
||||
writer: typing.Optional[asyncio.StreamWriter] = None
|
||||
|
||||
|
||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
transports: typing.MutableMapping[Connection, StreamIO]
|
||||
transports: typing.MutableMapping[Connection, ConnectionIO]
|
||||
timeout_watchdog: TimeoutWatchdog
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||
@ -81,93 +84,86 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
self.client = Client(addr)
|
||||
self.context = Context(self.client, options)
|
||||
self.layer = layer.NextLayer(self.context)
|
||||
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
||||
self.transports = {
|
||||
self.client: ConnectionIO(handler=None, reader=reader, writer=writer)
|
||||
}
|
||||
|
||||
# Ask for the first layer right away.
|
||||
# In a reverse proxy scenario, this is necessary as we would otherwise hang
|
||||
# on protocols that start with a server greeting.
|
||||
self.layer.ask_now()
|
||||
self.layer = layer.NextLayer(self.context, ask_on_start=True)
|
||||
|
||||
self.transports = {
|
||||
self.client: StreamIO(reader, writer)
|
||||
}
|
||||
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
||||
|
||||
async def handle_client(self) -> None:
|
||||
# FIXME: Work around log suppression in core.
|
||||
# Hack: Work around log suppression in core.
|
||||
logging.getLogger('asyncio').setLevel(logging.DEBUG)
|
||||
|
||||
asyncio.get_event_loop().set_debug(True)
|
||||
watch = asyncio.ensure_future(self.timeout_watchdog.watch())
|
||||
|
||||
self.log("[sans-io] clientconnect")
|
||||
|
||||
handler = asyncio.create_task(
|
||||
self.handle_connection(self.client)
|
||||
)
|
||||
self.transports[self.client].handler = handler
|
||||
self.server_event(events.Start())
|
||||
await self.handle_connection(self.client)
|
||||
await handler
|
||||
|
||||
self.log("[sans-io] clientdisconnected")
|
||||
watch.cancel()
|
||||
|
||||
if self.transports:
|
||||
self.log("[sans-io] closing transports...")
|
||||
await asyncio.wait([
|
||||
self.close_connection(x)
|
||||
for x in self.transports
|
||||
])
|
||||
for x in self.transports.values():
|
||||
x.handler.cancel()
|
||||
await asyncio.wait([x.handler for x in self.transports.values()])
|
||||
self.log("[sans-io] transports closed!")
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
self.log(f"Closing connection due to inactivity: {self.client}")
|
||||
await self.close_connection(self.client)
|
||||
|
||||
async def close_connection(self, connection: Connection) -> None:
|
||||
self.log(f"closing {connection}", "debug")
|
||||
connection.state = ConnectionState.CLOSED
|
||||
io = self.transports.pop(connection)
|
||||
io.w.close()
|
||||
await io.w.wait_closed()
|
||||
|
||||
async def shutdown_connection(self, connection: Connection) -> None:
|
||||
assert connection.state & ConnectionState.CAN_WRITE
|
||||
io = self.transports[connection]
|
||||
self.log(f"shutting down {connection}", "debug")
|
||||
|
||||
io.w.write_eof()
|
||||
connection.state &= ~ConnectionState.CAN_WRITE
|
||||
|
||||
async def handle_connection(self, connection: Connection) -> None:
|
||||
reader, writer = self.transports[connection]
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(65535)
|
||||
except socket.error:
|
||||
data = b""
|
||||
if data:
|
||||
self.server_event(events.DataReceived(connection, data))
|
||||
else:
|
||||
if connection.state is ConnectionState.CAN_READ:
|
||||
await self.close_connection(connection)
|
||||
else:
|
||||
connection.state &= ~ConnectionState.CAN_READ
|
||||
self.server_event(events.ConnectionClosed(connection))
|
||||
break
|
||||
|
||||
async def open_connection(self, command: commands.OpenConnection) -> None:
|
||||
if not command.connection.address:
|
||||
raise ValueError("Cannot open connection, no hostname given.")
|
||||
assert command.connection not in self.transports
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*command.connection.address
|
||||
)
|
||||
except IOError as e:
|
||||
reader, writer = await asyncio.open_connection(*command.connection.address)
|
||||
except (IOError, asyncio.CancelledError) as e:
|
||||
self.server_event(events.OpenConnectionReply(command, str(e)))
|
||||
else:
|
||||
self.log("serverconnect")
|
||||
self.transports[command.connection] = StreamIO(reader, writer)
|
||||
self.log(f"serverconnect {command.connection.address}")
|
||||
self.transports[command.connection].reader = reader
|
||||
self.transports[command.connection].writer = writer
|
||||
command.connection.state = ConnectionState.OPEN
|
||||
self.server_event(events.OpenConnectionReply(command, None))
|
||||
try:
|
||||
await self.handle_connection(command.connection)
|
||||
finally:
|
||||
self.log("serverdisconnected")
|
||||
|
||||
async def handle_connection(self, connection: Connection) -> None:
|
||||
reader = self.transports[connection].reader
|
||||
assert reader
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(65535)
|
||||
except (socket.error, asyncio.CancelledError):
|
||||
data = b""
|
||||
if data:
|
||||
self.server_event(events.DataReceived(connection, data))
|
||||
else:
|
||||
connection.state &= ~ConnectionState.CAN_READ
|
||||
self.server_event(events.ConnectionClosed(connection))
|
||||
if connection.state is ConnectionState.CLOSED:
|
||||
self.transports[connection].handler.cancel()
|
||||
await asyncio.Event().wait() # wait for cancellation
|
||||
except asyncio.CancelledError:
|
||||
connection.state = ConnectionState.CLOSED
|
||||
io = self.transports.pop(connection)
|
||||
io.writer.close()
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
self.log(f"Closing connection due to inactivity: {self.client}")
|
||||
self.transports[self.client].handler.cancel()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
pass
|
||||
@ -180,31 +176,24 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
try:
|
||||
layer_commands = self.layer.handle_event(event)
|
||||
for command in layer_commands:
|
||||
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
asyncio.ensure_future(
|
||||
assert command.connection not in self.transports
|
||||
handler = asyncio.create_task(
|
||||
self.open_connection(command)
|
||||
)
|
||||
self.transports[command.connection] = ConnectionIO(handler=handler)
|
||||
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
|
||||
return # The connection has already been closed.
|
||||
elif isinstance(command, commands.SendData):
|
||||
try:
|
||||
io = self.transports[command.connection]
|
||||
except KeyError:
|
||||
raise RuntimeError(f"Cannot write to closed connection: {command.connection}")
|
||||
else:
|
||||
io.w.write(command.data)
|
||||
self.transports[command.connection].writer.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
if command.connection == self.client:
|
||||
asyncio.ensure_future(
|
||||
self.close_connection(command.connection)
|
||||
)
|
||||
else:
|
||||
asyncio.ensure_future(
|
||||
self.shutdown_connection(command.connection)
|
||||
)
|
||||
self.close_our_end(command.connection)
|
||||
elif isinstance(command, commands.GetSocket):
|
||||
socket = self.transports[command.connection].w.get_extra_info("socket")
|
||||
socket = self.transports[command.connection].writer.get_extra_info("socket")
|
||||
self.server_event(events.GetSocketReply(command, socket))
|
||||
elif isinstance(command, commands.Hook):
|
||||
asyncio.ensure_future(
|
||||
asyncio.create_task(
|
||||
self.handle_hook(command)
|
||||
)
|
||||
elif isinstance(command, commands.Log):
|
||||
@ -214,6 +203,22 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
except Exception:
|
||||
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
|
||||
|
||||
def close_our_end(self, connection):
|
||||
assert connection.state & ConnectionState.CAN_WRITE
|
||||
self.log(f"shutting down {connection}", "debug")
|
||||
try:
|
||||
self.transports[connection].writer.write_eof()
|
||||
except socket.error:
|
||||
connection.state = ConnectionState.CLOSED
|
||||
connection.state &= ~ConnectionState.CAN_WRITE
|
||||
|
||||
# if we are closing the client connection, we should destroy everything.
|
||||
if connection == self.client:
|
||||
self.transports[connection].handler.cancel()
|
||||
# If we have already received a close, let's finish everything.
|
||||
elif connection.state is ConnectionState.CLOSED:
|
||||
self.transports[connection].handler.cancel()
|
||||
|
||||
|
||||
class SimpleConnectionHandler(ConnectionHandler):
|
||||
"""Simple handler that does not really process any hooks."""
|
||||
@ -253,10 +258,10 @@ if __name__ == "__main__":
|
||||
async def handle(reader, writer):
|
||||
layer_stack = [
|
||||
lambda ctx: layers.ServerTLSLayer(ctx),
|
||||
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
|
||||
lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular),
|
||||
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
|
||||
lambda ctx: layers.ClientTLSLayer(ctx),
|
||||
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.transparent)
|
||||
lambda ctx: layers.HttpLayer(ctx, HTTPMode.transparent)
|
||||
]
|
||||
|
||||
def next_layer(nl: layer.NextLayer):
|
||||
|
@ -6,13 +6,15 @@ import pytest
|
||||
from mitmproxy.net.websockets import Frame, OPCODE
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.layers.old import websocket
|
||||
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.test import tflow
|
||||
from .. import tutils
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ws_playbook(tctx):
|
||||
tctx.server.connected = True
|
||||
tctx.server.state = ConnectionState.OPEN
|
||||
playbook = tutils.Playbook(
|
||||
websocket.WebsocketLayer(
|
||||
tctx,
|
||||
|
@ -6,7 +6,7 @@ from mitmproxy.proxy2 import layer
|
||||
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import http, tls
|
||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
|
||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_next_layer
|
||||
|
||||
|
||||
def test_http_proxy(tctx):
|
||||
@ -14,7 +14,7 @@ def test_http_proxy(tctx):
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
|
||||
Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< http.HttpRequestHeadersHook(flow)
|
||||
>> reply()
|
||||
@ -39,7 +39,7 @@ def test_https_proxy(strategy, tctx):
|
||||
"""Test a CONNECT request, followed by a HTTP GET /"""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
tctx.options.connection_strategy = strategy
|
||||
|
||||
(playbook
|
||||
@ -54,7 +54,7 @@ def test_https_proxy(strategy, tctx):
|
||||
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< layer.NextLayerHook(Placeholder())
|
||||
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
||||
>> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
|
||||
<< http.HttpRequestHeadersHook(flow)
|
||||
>> reply()
|
||||
<< http.HttpRequestHook(flow)
|
||||
@ -77,12 +77,15 @@ def test_https_proxy(strategy, tctx):
|
||||
@pytest.mark.parametrize("https_client", [False, True])
|
||||
@pytest.mark.parametrize("https_server", [False, True])
|
||||
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
|
||||
def test_redirect(strategy, https_server, https_client, tctx):
|
||||
def test_redirect(strategy, https_server, https_client, tctx, monkeypatch):
|
||||
"""Test redirects between http:// and https:// in regular proxy mode."""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
tctx.options.connection_strategy = strategy
|
||||
p = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
p = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
|
||||
if https_server:
|
||||
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
|
||||
|
||||
def redirect(flow: HTTPFlow):
|
||||
if https_server:
|
||||
@ -98,16 +101,13 @@ def test_redirect(strategy, https_server, https_client, tctx):
|
||||
p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
p << layer.NextLayerHook(Placeholder())
|
||||
p >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
||||
p >> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
|
||||
else:
|
||||
p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
p << http.HttpRequestHook(flow)
|
||||
p >> reply(side_effect=redirect)
|
||||
p << OpenConnection(server)
|
||||
p >> reply(None)
|
||||
if https_server:
|
||||
pass # p << tls.EstablishServerTLS(server)
|
||||
# p >> reply_establish_server_tls()
|
||||
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
|
||||
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
@ -123,7 +123,7 @@ def test_multiple_server_connections(tctx):
|
||||
"""Test multiple requests being rewritten to different targets."""
|
||||
server1 = Placeholder()
|
||||
server2 = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
|
||||
def redirect(to: str):
|
||||
def side_effect(flow: HTTPFlow):
|
||||
@ -164,7 +164,7 @@ def test_http_reply_from_proxy(tctx):
|
||||
flow.response = HTTPResponse.make(418)
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< http.HttpRequestHook(Placeholder())
|
||||
>> reply(side_effect=reply_from_proxy)
|
||||
@ -176,7 +176,7 @@ def test_response_until_eof(tctx):
|
||||
"""Test scenario where the server response body is terminated by EOF."""
|
||||
server = Placeholder()
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
@ -197,7 +197,7 @@ def test_disconnect_while_intercept(tctx):
|
||||
flow = Placeholder()
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
|
||||
<< http.HttpConnectHook(Placeholder())
|
||||
>> reply()
|
||||
@ -206,7 +206,7 @@ def test_disconnect_while_intercept(tctx):
|
||||
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< layer.NextLayerHook(Placeholder())
|
||||
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
||||
>> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
|
||||
<< http.HttpRequestHook(flow)
|
||||
>> ConnectionClosed(server1)
|
||||
>> reply(to=-2)
|
||||
@ -229,7 +229,7 @@ def test_response_streaming(tctx):
|
||||
flow.response.stream = lambda x: x.upper()
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
@ -252,7 +252,7 @@ def test_request_streaming(tctx, response):
|
||||
"""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
|
||||
def enable_streaming(flow: HTTPFlow):
|
||||
flow.request.stream = lambda x: x.upper()
|
||||
@ -324,7 +324,7 @@ def test_server_aborts(tctx, data):
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
err = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
@ -1,4 +1,5 @@
|
||||
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import tcp
|
||||
from ..tutils import Placeholder, Playbook, reply
|
||||
@ -14,7 +15,7 @@ def test_open_connection(tctx):
|
||||
<< OpenConnection(tctx.server)
|
||||
)
|
||||
|
||||
tctx.server.connected = True
|
||||
tctx.server.state = ConnectionState.OPEN
|
||||
assert (
|
||||
Playbook(tcp.TCPLayer(tctx, True))
|
||||
<< None
|
||||
|
@ -5,6 +5,7 @@ import pytest
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy.proxy2 import commands, context, events, layer
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.utils import data
|
||||
from test.mitmproxy.proxy2 import tutils
|
||||
@ -109,7 +110,6 @@ class TlsEchoLayer(tutils.EchoLayer):
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
|
||||
# noinspection PyTypeChecker
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
|
||||
@ -180,23 +180,20 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
||||
|
||||
|
||||
class TestServerTLS:
|
||||
def test_no_tls(self, tctx: context.Context):
|
||||
"""Test TLS layer without TLS"""
|
||||
def test_not_connected(self, tctx: context.Context):
|
||||
"""Test that we don't do anything if no server connection exists."""
|
||||
layer = tls.ServerTLSLayer(tctx)
|
||||
layer.child_layer = TlsEchoLayer(tctx)
|
||||
|
||||
# Handshake
|
||||
assert (
|
||||
tutils.Playbook(layer)
|
||||
>> events.DataReceived(tctx.client, b"Hello World")
|
||||
<< commands.SendData(tctx.client, b"hello world")
|
||||
>> events.DataReceived(tctx.server, b"Foo")
|
||||
<< commands.SendData(tctx.server, b"foo")
|
||||
)
|
||||
|
||||
def test_simple(self, tctx):
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.connected = True
|
||||
tctx.server.state = ConnectionState.OPEN
|
||||
tctx.server.address = ("example.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"example.mitmproxy.org"
|
||||
|
||||
@ -250,7 +247,6 @@ class TestServerTLS:
|
||||
def test_untrusted_cert(self, tctx):
|
||||
"""If the certificate is not trusted, we should fail."""
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.connected = True
|
||||
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"wrong.host.mitmproxy.org"
|
||||
|
||||
@ -260,9 +256,11 @@ class TestServerTLS:
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"establish-server-tls")
|
||||
>> events.DataReceived(tctx.client, b"open-connection")
|
||||
<< layer.NextLayerHook(tutils.Placeholder())
|
||||
>> tutils.reply_next_layer(TlsEchoLayer)
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> tutils.reply(None)
|
||||
<< tls.TlsStartHook(tutils.Placeholder())
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.server, data)
|
||||
@ -278,7 +276,8 @@ class TestServerTLS:
|
||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
|
||||
<< commands.CloseConnection(tctx.server)
|
||||
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
|
||||
<< commands.SendData(tctx.client,
|
||||
b"open-connection failed: Certificate verify failed: Hostname mismatch")
|
||||
)
|
||||
assert not tctx.server.tls_established
|
||||
|
||||
@ -334,10 +333,11 @@ class TestClientTLS:
|
||||
|
||||
# Echo
|
||||
_test_echo(playbook, tssl_client, tctx.client)
|
||||
other_server = context.Server(None)
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, b"Plaintext")
|
||||
<< commands.SendData(tctx.server, b"plaintext")
|
||||
>> events.DataReceived(other_server, b"Plaintext")
|
||||
<< commands.SendData(other_server, b"plaintext")
|
||||
)
|
||||
|
||||
def test_server_required(self, tctx):
|
||||
@ -419,6 +419,7 @@ class TestClientTLS:
|
||||
def test_mitmproxy_ca_is_untrusted(self, tctx: context.Context):
|
||||
"""Test the scenario where the client doesn't trust the mitmproxy CA."""
|
||||
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org")
|
||||
playbook.logs = True
|
||||
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
@ -440,6 +441,7 @@ class TestClientTLS:
|
||||
<< commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate "
|
||||
"for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn")
|
||||
<< commands.CloseConnection(tctx.client)
|
||||
>> events.ConnectionClosed(tctx.client)
|
||||
)
|
||||
assert not tctx.client.tls_established
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
@ -20,6 +21,7 @@ class TCommand(commands.Command):
|
||||
self.x = x
|
||||
|
||||
|
||||
@dataclass
|
||||
class TCommandReply(events.CommandReply):
|
||||
command: TCommand
|
||||
|
||||
@ -157,7 +159,7 @@ def test_command_reply(tplaybook):
|
||||
tplaybook
|
||||
>> TEvent()
|
||||
<< TCommand()
|
||||
>> tutils.reply(42)
|
||||
>> tutils.reply()
|
||||
)
|
||||
assert tplaybook.actual[1] == tplaybook.actual[2].command
|
||||
|
||||
|
@ -10,8 +10,7 @@ from mitmproxy.proxy2 import commands, context, layer
|
||||
from mitmproxy.proxy2 import events
|
||||
from mitmproxy.proxy2.context import ConnectionState
|
||||
from mitmproxy.proxy2.events import command_reply_subclasses
|
||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.proxy2.layer import Layer
|
||||
|
||||
PlaybookEntry = typing.Union[commands.Command, events.Event]
|
||||
PlaybookEntryList = typing.List[PlaybookEntry]
|
||||
@ -101,7 +100,7 @@ class Playbook:
|
||||
|
||||
assert playbook(tcp.TCPLayer(tctx)) \
|
||||
<< commands.OpenConnection(tctx.server)
|
||||
>> events.OpenConnectionReply(-1, "ok") # -1 = reply to command in previous line.
|
||||
>> reply(None)
|
||||
<< None # this line is optional.
|
||||
|
||||
This is syntactic sugar for the following:
|
||||
@ -351,15 +350,3 @@ def reply_next_layer(
|
||||
next_layer.layer = child_layer(next_layer.context)
|
||||
|
||||
return reply(*args, side_effect=set_layer, **kwargs)
|
||||
|
||||
|
||||
def reply_establish_server_tls(**kwargs) -> reply:
|
||||
"""Helper function to simplify the syntax for EstablishServerTls events to this:
|
||||
<< tls.EstablishServerTLS(server)
|
||||
>> tutils.reply_establish_server_tls()
|
||||
"""
|
||||
|
||||
def fake_tls(cmd: tls.EstablishServerTLS) -> None:
|
||||
cmd.connection.tls_established = True
|
||||
|
||||
return reply(None, side_effect=fake_tls, **kwargs)
|
||||
|
@ -84,7 +84,7 @@ class TestConnectionHandler:
|
||||
def ask(_, x):
|
||||
raise RuntimeError
|
||||
|
||||
channel.ask = ask
|
||||
channel._ask = ask
|
||||
c = ConnectionHandler(
|
||||
mock.MagicMock(),
|
||||
("127.0.0.1", 8080),
|
||||
|
Loading…
Reference in New Issue
Block a user