[sans-io] fixes, fixes, fixes

This commit is contained in:
Maximilian Hils 2020-01-02 03:09:48 +01:00
parent b2060356b6
commit a30a6758f3
20 changed files with 236 additions and 215 deletions

View File

@ -86,12 +86,11 @@ class NextLayer:
def next_layer(self, nextlayer: layer.NextLayer): def next_layer(self, nextlayer: layer.NextLayer):
nextlayer.layer = self._next_layer(nextlayer.context, nextlayer.data_client()) nextlayer.layer = self._next_layer(nextlayer.context, nextlayer.data_client())
# nextlayer.layer.debug = " " * len(nextlayer.context.layers)
def _next_layer(self, context: context.Context, data_client: bytes) -> typing.Optional[layer.Layer]: def _next_layer(self, context: context.Context, data_client: bytes) -> typing.Optional[layer.Layer]:
if len(context.layers) == 0: if len(context.layers) == 0:
return self.make_top_layer(context) return self.make_top_layer(context)
if len(context.layers) == 1:
return layers.ServerTLSLayer(context)
if len(data_client) < 3: if len(data_client) < 3:
return return
@ -113,21 +112,14 @@ class NextLayer:
if isinstance(top_layer, layers.ServerTLSLayer): if isinstance(top_layer, layers.ServerTLSLayer):
return layers.ClientTLSLayer(context) return layers.ClientTLSLayer(context)
else: else:
if s(modes.HttpProxy):
# A "Secure Web Proxy" (https://www.chromium.org/developers/design-documents/secure-web-proxy)
# This does not imply TLS on the server side.
pass
else:
# In all other cases, client TLS implies TLS for both ends.
context.server.tls = True
return layers.ServerTLSLayer(context) return layers.ServerTLSLayer(context)
# 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy. # 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy.
if any([ if any([
s(modes.HttpProxy, layers.ServerTLSLayer), s(modes.HttpProxy),
s(modes.HttpProxy, layers.ServerTLSLayer, layers.ClientTLSLayer), s(modes.HttpProxy, layers.ClientTLSLayer),
]): ]):
return layers.HTTPLayer(context, HTTPMode.regular) return layers.HttpLayer(context, HTTPMode.regular)
if ctx.options.mode.startswith("upstream:") and len(context.layers) <= 3 and isinstance(top_layer, if ctx.options.mode.startswith("upstream:") and len(context.layers) <= 3 and isinstance(top_layer,
layers.ServerTLSLayer): layers.ServerTLSLayer):
raise NotImplementedError() raise NotImplementedError()
@ -153,7 +145,7 @@ class NextLayer:
return layers.TCPLayer(context) return layers.TCPLayer(context)
# 6. Assume HTTP by default. # 6. Assume HTTP by default.
return layers.HTTPLayer(context, HTTPMode.transparent) return layers.HttpLayer(context, HTTPMode.transparent)
def make_top_layer(self, context: context.Context) -> layer.Layer: def make_top_layer(self, context: context.Context) -> layer.Layer:
if ctx.options.mode == "regular": if ctx.options.mode == "regular":

View File

@ -13,9 +13,11 @@ from mitmproxy.proxy2.layers import tls
def alpn_select_callback(conn: SSL.Connection, options): def alpn_select_callback(conn: SSL.Connection, options):
server_alpn = conn.get_app_data()["server_alpn"] server_alpn = conn.get_app_data()["server_alpn"]
http2 = conn.get_app_data()["http2"]
if server_alpn and server_alpn in options: if server_alpn and server_alpn in options:
return server_alpn return server_alpn
for alpn in tls.HTTP_ALPNS: http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
for alpn in http_alpns:
if alpn in options: if alpn in options:
return alpn return alpn
else: else:
@ -76,11 +78,11 @@ class TlsConfig:
) )
tls_start.ssl_conn = SSL.Connection(ssl_ctx) tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_app_data({ tls_start.ssl_conn.set_app_data({
"server_alpn": tls_start.context.server.alpn "server_alpn": tls_start.context.server.alpn,
"http2": ctx.options.http2,
}) })
tls_start.ssl_conn.set_accept_state() tls_start.ssl_conn.set_accept_state()
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None: def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None:
client = tls_start.context.client client = tls_start.context.client
server: context.Server = tls_start.conn server: context.Server = tls_start.conn
@ -120,7 +122,7 @@ class TlsConfig:
if os.path.exists(path): if os.path.exists(path):
client_cert = path client_cert = path
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None args["cipher_list"] = b':'.join(server.cipher_list) if server.cipher_list else None
ssl_ctx = net_tls.create_client_context( ssl_ctx = net_tls.create_client_context(
cert=client_cert, cert=client_cert,
sni=server.sni.decode("idna"), # FIXME: Should pass-through here. sni=server.sni.decode("idna"), # FIXME: Should pass-through here.

View File

@ -1,5 +1,5 @@
from enum import Flag, auto from enum import Flag, auto
from typing import List, Optional, Sequence, Union from typing import List, Literal, Optional, Sequence, Union
from mitmproxy import certs from mitmproxy import certs
from mitmproxy.options import Options from mitmproxy.options import Options
@ -16,7 +16,7 @@ class Connection:
""" """
Connections exposed to the layers only contain metadata, no socket objects. Connections exposed to the layers only contain metadata, no socket objects.
""" """
address: tuple address: Optional[tuple]
state: ConnectionState state: ConnectionState
tls: bool = False tls: bool = False
tls_established: bool = False tls_established: bool = False
@ -25,7 +25,7 @@ class Connection:
alpn_offers: Sequence[bytes] = () alpn_offers: Sequence[bytes] = ()
cipher_list: Sequence[bytes] = () cipher_list: Sequence[bytes] = ()
tls_version: Optional[str] = None tls_version: Optional[str] = None
sni: Union[bytes, bool, None] sni: Union[bytes, Literal[True], None]
timestamp_tls_setup: Optional[float] = None timestamp_tls_setup: Optional[float] = None
@ -33,21 +33,17 @@ class Connection:
def connected(self): def connected(self):
return self.state is ConnectionState.OPEN return self.state is ConnectionState.OPEN
@connected.setter
def connected(self, val: bool) -> None:
# We should really set .state, but verdict is still due if we even want to keep .state around.
# We allow setting .connected while we figure that out.
if val:
self.state = ConnectionState.OPEN
else:
self.state = ConnectionState.CLOSED
def __repr__(self): def __repr__(self):
return f"{type(self).__name__}({repr(self.__dict__)})" attrs = repr({
k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k,lambda: v)()
for k, v in self.__dict__.items()
})
return f"{type(self).__name__}({attrs})"
class Client(Connection): class Client(Connection):
sni: Optional[bytes] = None sni: Union[bytes, None] = None
address: tuple
def __init__(self, address): def __init__(self, address):
self.address = address self.address = address
@ -55,9 +51,9 @@ class Client(Connection):
class Server(Connection): class Server(Connection):
sni: Union[bytes, bool] = True sni = True
"""True: client SNI, False: no SNI, bytes: custom value""" """True: client SNI, False: no SNI, bytes: custom value"""
address: Optional[tuple] via: Optional["Server"] = None
def __init__(self, address: Optional[tuple]): def __init__(self, address: Optional[tuple]):
self.address = address self.address = address

View File

@ -5,7 +5,7 @@ The counterpart to events are commands.
""" """
import socket import socket
import typing import typing
from dataclasses import dataclass from dataclasses import dataclass, is_dataclass
from mitmproxy.proxy2 import commands from mitmproxy.proxy2 import commands
from mitmproxy.proxy2.context import Connection from mitmproxy.proxy2.context import Connection
@ -55,7 +55,6 @@ class DataReceived(ConnectionEvent):
return f"DataReceived({target}, {self.data})" return f"DataReceived({target}, {self.data})"
@dataclass
class CommandReply(Event): class CommandReply(Event):
""" """
Emitted when a command has been finished, e.g. Emitted when a command has been finished, e.g.
@ -65,6 +64,7 @@ class CommandReply(Event):
reply: typing.Any reply: typing.Any
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
assert is_dataclass(cls)
if cls is CommandReply: if cls is CommandReply:
raise TypeError("CommandReply may not be instantiated directly.") raise TypeError("CommandReply may not be instantiated directly.")
return super().__new__(cls) return super().__new__(cls)

View File

@ -5,6 +5,7 @@ import collections
import textwrap import textwrap
import typing import typing
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass
from mitmproxy import log from mitmproxy import log
from mitmproxy.proxy2 import commands, events from mitmproxy.proxy2 import commands, events
@ -149,11 +150,14 @@ class NextLayer(Layer):
events: typing.List[mevents.Event] events: typing.List[mevents.Event]
"""All events that happened before a decision was made.""" """All events that happened before a decision was made."""
def __init__(self, context: Context) -> None: _ask_on_start: bool
def __init__(self, context: Context, ask_on_start: bool = False) -> None:
super().__init__(context) super().__init__(context)
self.context.layers.remove(self) self.context.layers.remove(self)
self.events = []
self.layer = None self.layer = None
self.events = []
self._ask_on_start = ask_on_start
self._handle = None self._handle = None
def __repr__(self): def __repr__(self):
@ -169,11 +173,13 @@ class NextLayer(Layer):
self.events.append(event) self.events.append(event)
# We receive new data. Let's find out if we can determine the next layer now? # We receive new data. Let's find out if we can determine the next layer now?
if isinstance(event, mevents.DataReceived): if self._ask_on_start and isinstance(event, events.Start):
yield from self._ask()
elif isinstance(event, mevents.DataReceived):
# For now, we only ask if we have received new data to reduce hook noise. # For now, we only ask if we have received new data to reduce hook noise.
yield from self.ask_now() yield from self._ask()
def ask_now(self): def _ask(self):
""" """
Manually trigger a next_layer hook. Manually trigger a next_layer hook.
The only use at the moment is to make sure that the top layer is initialized. The only use at the moment is to make sure that the top layer is initialized.

View File

@ -1,11 +1,11 @@
from . import modes from . import modes
from .http import HTTPLayer from .http import HttpLayer
from .tcp import TCPLayer from .tcp import TCPLayer
from .tls import ClientTLSLayer, ServerTLSLayer from .tls import ClientTLSLayer, ServerTLSLayer
__all__ = [ __all__ = [
"modes", "modes",
"HTTPLayer", "HttpLayer",
"TCPLayer", "TCPLayer",
"ClientTLSLayer", "ServerTLSLayer", "ClientTLSLayer", "ServerTLSLayer",
] ]

View File

@ -1,5 +1,6 @@
import collections import collections
import typing import typing
from dataclasses import dataclass
from mitmproxy import flow, http from mitmproxy import flow, http
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
@ -8,7 +9,7 @@ from mitmproxy.proxy2.context import Connection, Context, Server
from mitmproxy.proxy2.layers import tls from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \ from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
@ -37,6 +38,7 @@ class GetHttpConnection(HttpCommand):
) )
@dataclass
class GetHttpConnectionReply(events.CommandReply): class GetHttpConnectionReply(events.CommandReply):
command: GetHttpConnection command: GetHttpConnection
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]] reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
@ -67,8 +69,6 @@ class SendHttp(HttpCommand):
return f"Send({self.event})" return f"Send({self.event})"
class HttpStream(layer.Layer): class HttpStream(layer.Layer):
request_body_buf: bytes request_body_buf: bytes
response_body_buf: bytes response_body_buf: bytes
@ -78,7 +78,7 @@ class HttpStream(layer.Layer):
@property @property
def mode(self): def mode(self):
parent: HTTPLayer = self.context.layers[-2] parent: HttpLayer = self.context.layers[-2]
return parent.mode return parent.mode
def __init__(self, context: Context): def __init__(self, context: Context):
@ -309,7 +309,7 @@ class HttpStream(layer.Layer):
yield from () yield from ()
class HTTPLayer(layer.Layer): class HttpLayer(layer.Layer):
""" """
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client. ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
HttpEvent: We have received request headers HttpEvent: We have received request headers
@ -339,22 +339,22 @@ class HTTPLayer(layer.Layer):
} }
def __repr__(self): def __repr__(self):
return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})" return f"HttpLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
def _handle_event(self, event: events.Event): def _handle_event(self, event: events.Event):
if isinstance(event, events.Start): if isinstance(event, events.Start):
return return
elif isinstance(event, events.CommandReply): elif isinstance(event, events.CommandReply):
try:
stream = self.stream_by_command.pop(event.command) stream = self.stream_by_command.pop(event.command)
except KeyError:
raise
self.event_to_child(stream, event) self.event_to_child(stream, event)
elif isinstance(event, events.ConnectionEvent): elif isinstance(event, events.ConnectionEvent):
if event.connection == self.context.server and self.context.server not in self.connections: if event.connection == self.context.server and self.context.server not in self.connections:
pass pass
else: else:
try:
handler = self.connections[event.connection] handler = self.connections[event.connection]
except KeyError:
raise
self.event_to_child(handler, event) self.event_to_child(handler, event)
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
@ -388,7 +388,7 @@ class HTTPLayer(layer.Layer):
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True): def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
# Do we already have a connection we can re-use? # Do we already have a connection we can re-use?
for connection, layer in self.connections.items(): for connection in self.connections:
connection_suitable = ( connection_suitable = (
reuse and reuse and
event.connection_spec_matches(connection) and event.connection_spec_matches(connection) and
@ -402,7 +402,7 @@ class HTTPLayer(layer.Layer):
self.waiting_for_establishment[connection].append(event) self.waiting_for_establishment[connection].append(event)
else: else:
stream = self.stream_by_command.pop(event) stream = self.stream_by_command.pop(event)
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None))) self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None)))
return return
can_reuse_context_connection = ( can_reuse_context_connection = (
@ -414,8 +414,11 @@ class HTTPLayer(layer.Layer):
layer = HttpClient(context) layer = HttpClient(context)
if not can_reuse_context_connection: if not can_reuse_context_connection:
context.server = Server(event.address) context.server = Server(event.address)
if context.options.http2:
context.server.alpn_offers = tls.HTTP_ALPNS
else:
context.server.alpn_offers = tls.HTTP1_ALPNS
if event.tls: if event.tls:
context.server.tls = True
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__ # TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
orig = layer orig = layer
layer = tls.ServerTLSLayer(context) layer = tls.ServerTLSLayer(context)
@ -432,7 +435,6 @@ class HTTPLayer(layer.Layer):
if command.err: if command.err:
reply = (None, command.err) reply = (None, command.err)
self.connections.pop(command.connection)
else: else:
reply = (command.connection, None) reply = (command.connection, None)

View File

@ -1,3 +1,5 @@
from dataclasses import dataclass
from mitmproxy import http from mitmproxy import http
from mitmproxy.proxy2 import commands from mitmproxy.proxy2 import commands

View File

@ -12,6 +12,7 @@ from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Client, Connection, Context, Server from mitmproxy.proxy2.context import Client, Connection, Context, Server
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
from ._base import HttpConnection from ._base import HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
@ -131,7 +132,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower(): if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n") yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request, self.response) == -1: elif http1.expected_http_body_size(self.request, self.response) == -1 or self.request.first_line_format == "authority":
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
yield from self.mark_done(response=True) yield from self.mark_done(response=True)
elif isinstance(event, ResponseProtocolError): elif isinstance(event, ResponseProtocolError):
@ -162,7 +163,13 @@ class Http1Server(Http1Connection):
request_head = self.buf.maybe_extract_lines() request_head = self.buf.maybe_extract_lines()
if request_head: if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
try:
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head)) self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
except ValueError as e:
yield commands.Log(f"{human.format_address(self.conn.address)}: {e}")
yield commands.CloseConnection(self.conn)
self.state = self.wait
return
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request)) yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
if self.request.first_line_format == "authority": if self.request.first_line_format == "authority":
@ -181,6 +188,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.ConnectionClosed):
if bytes(self.buf).strip(): if bytes(self.buf).strip():
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}") yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
yield commands.Log(f"Receive Buffer: {bytes(self.buf)}", level="debug")
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")

View File

@ -2,6 +2,7 @@ from mitmproxy import platform
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from mitmproxy.proxy2 import commands, events, layer from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Server from mitmproxy.proxy2.context import Server
from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
@ -11,9 +12,10 @@ class ReverseProxy(layer.Layer):
spec = server_spec.parse_with_mode(self.context.options.mode)[1] spec = server_spec.parse_with_mode(self.context.options.mode)[1]
self.context.server = Server(spec.address) self.context.server = Server(spec.address)
if spec.scheme not in ("http", "tcp"): if spec.scheme not in ("http", "tcp"):
self.context.server.tls = True
if not self.context.options.keep_host_header: if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0] self.context.server.sni = spec.address[0]
child_layer = tls.ServerTLSLayer(self.context)
else:
child_layer = layer.NextLayer(self.context) child_layer = layer.NextLayer(self.context)
self._handle_event = child_layer.handle_event self._handle_event = child_layer.handle_event
yield from child_layer.handle_event(event) yield from child_layer.handle_event(event)

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Optional from typing import Optional
from mitmproxy import flow, tcp from mitmproxy import flow, tcp

View File

@ -90,7 +90,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
return None return None
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9") HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
# We need these classes as hooks can only have one argument at the moment. # We need these classes as hooks can only have one argument at the moment.
@ -122,10 +123,16 @@ class _TLSLayer(layer.Layer):
tls: SSL.Connection = None tls: SSL.Connection = None
"""The OpenSSL connection object""" """The OpenSSL connection object"""
child_layer: layer.Layer child_layer: layer.Layer
errored: bool = False
"""Have we errored yet?"""
def __init__(self, context: context.Context, conn: context.Connection): def __init__(self, context: context.Context, conn: context.Connection):
super().__init__(context) super().__init__(context)
assert not conn.tls
conn.tls = True
self.conn = conn self.conn = conn
self.child_layer = layer.NextLayer(self.context)
def __repr__(self): def __repr__(self):
if not self.tls: if not self.tls:
@ -138,8 +145,6 @@ class _TLSLayer(layer.Layer):
def start_tls(self, initial_data: bytes = b""): def start_tls(self, initial_data: bytes = b""):
assert not self.tls assert not self.tls
assert self.conn.connected
self.conn.tls = True
tls_start = TlsStartData(self.conn, self.context) tls_start = TlsStartData(self.conn, self.context)
yield TlsStartHook(tls_start) yield TlsStartHook(tls_start)
@ -227,7 +232,7 @@ class _TLSLayer(layer.Layer):
) )
if close: if close:
self.conn.state &= ~context.ConnectionState.CAN_READ self.conn.state &= ~context.ConnectionState.CAN_READ
yield commands.Log(f"TLS close_notify {self.conn}") yield commands.Log(f"TLS close_notify {self.conn}", level="debug")
yield from self.event_to_child( yield from self.event_to_child(
events.ConnectionClosed(self.conn) events.ConnectionClosed(self.conn)
) )
@ -252,12 +257,13 @@ class _TLSLayer(layer.Layer):
pass # We have already dispatched a ConnectionClosed to the child layer. pass # We have already dispatched a ConnectionClosed to the child layer.
else: else:
yield from self.event_to_child(event) yield from self.event_to_child(event)
else: elif not self.errored:
yield from self.on_handshake_error("connection closed without notice") yield from self.on_handshake_error("connection closed without notice")
else: else:
yield from self.event_to_child(event) yield from self.event_to_child(event)
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]: def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
self.errored = True
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
@ -269,13 +275,12 @@ class ServerTLSLayer(_TLSLayer):
def __init__(self, context: context.Context): def __init__(self, context: context.Context):
super().__init__(context, context.server) super().__init__(context, context.server)
self.child_layer = layer.NextLayer(self.context)
@expect(events.Start) @expect(events.Start)
def state_start(self, _) -> layer.CommandGenerator[None]: def state_start(self, event) -> layer.CommandGenerator[None]:
self.context.server.tls = True
if self.context.server.connected: if self.context.server.connected:
yield from self.start_tls() yield from self.start_tls()
yield from self.event_to_child(event)
self._handle_event = super()._handle_event self._handle_event = super()._handle_event
_handle_event = state_start _handle_event = state_start
@ -327,20 +332,12 @@ class ClientTLSLayer(_TLSLayer):
""" """
recv_buffer: bytearray recv_buffer: bytearray
server_tls_available: bool
def __init__(self, context: context.Context): def __init__(self, context: context.Context):
assert isinstance(context.layers[-1], ServerTLSLayer) super().__init__(context, context.client)
super().__init__(context, self.context.client) self.server_tls_available = isinstance(self.context.layers[-2], ServerTLSLayer)
self.recv_buffer = bytearray() self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context)
@expect(events.Start)
def state_start(self, _) -> layer.CommandGenerator[None]:
self.context.client.tls = True
self._handle_event = self.state_wait_for_clienthello
yield from ()
_handle_event = state_start
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]: def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.connection == self.conn: if isinstance(event, events.DataReceived) and event.connection == self.conn:
@ -361,8 +358,8 @@ class ClientTLSLayer(_TLSLayer):
if tls_clienthello.establish_server_tls_first and not self.context.server.tls_established: if tls_clienthello.establish_server_tls_first and not self.context.server.tls_established:
err = yield from self.start_server_tls() err = yield from self.start_server_tls()
if err: if err:
yield commands.Log("Unable to establish TLS connection with server. " yield commands.Log(f"Unable to establish TLS connection with server ({err}). "
"Trying to establish TLS with client anyway.") f"Trying to establish TLS with client anyway.")
yield from self.start_tls(bytes(self.recv_buffer)) yield from self.start_tls(bytes(self.recv_buffer))
self.recv_buffer.clear() self.recv_buffer.clear()
@ -373,14 +370,17 @@ class ClientTLSLayer(_TLSLayer):
else: else:
yield from self.event_to_child(event) yield from self.event_to_child(event)
_handle_event = state_wait_for_clienthello
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]: def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
""" """
We often need information from the upstream connection to establish TLS with the client. We often need information from the upstream connection to establish TLS with the client.
For example, we need to check if the client does ALPN or not. For example, we need to check if the client does ALPN or not.
""" """
if not self.server_tls_available:
return "No server TLS available."
err = yield commands.OpenConnection(self.context.server) err = yield commands.OpenConnection(self.context.server)
if err: if err:
yield commands.Log(f"Cannot establish server connection: {err}")
return err return err
else: else:
return None return None
@ -400,3 +400,14 @@ class ClientTLSLayer(_TLSLayer):
level="warn" level="warn"
) )
yield from super().on_handshake_error(err) yield from super().on_handshake_error(err)
class MockTLSLayer(_TLSLayer):
"""Mock layer to disable actual TLS and use cleartext in tests.
Use like so:
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
"""
def __init__(self, ctx: context.Context):
super().__init__(ctx, context.Server(None))

View File

@ -15,6 +15,7 @@ import time
import traceback import traceback
import typing import typing
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from OpenSSL import SSL from OpenSSL import SSL
@ -27,11 +28,6 @@ from mitmproxy.utils import human
from mitmproxy.utils.data import pkg_data from mitmproxy.utils.data import pkg_data
class StreamIO(typing.NamedTuple):
r: asyncio.StreamReader
w: asyncio.StreamWriter
class TimeoutWatchdog: class TimeoutWatchdog:
last_activity: float last_activity: float
CONNECTION_TIMEOUT = 10 * 60 CONNECTION_TIMEOUT = 10 * 60
@ -72,8 +68,15 @@ class TimeoutWatchdog:
self.can_timeout.set() self.can_timeout.set()
@dataclass
class ConnectionIO:
handler: typing.Optional[asyncio.Task] = None
reader: typing.Optional[asyncio.StreamReader] = None
writer: typing.Optional[asyncio.StreamWriter] = None
class ConnectionHandler(metaclass=abc.ABCMeta): class ConnectionHandler(metaclass=abc.ABCMeta):
transports: typing.MutableMapping[Connection, StreamIO] transports: typing.MutableMapping[Connection, ConnectionIO]
timeout_watchdog: TimeoutWatchdog timeout_watchdog: TimeoutWatchdog
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None: def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
@ -81,93 +84,86 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.client = Client(addr) self.client = Client(addr)
self.context = Context(self.client, options) self.context = Context(self.client, options)
self.layer = layer.NextLayer(self.context) self.transports = {
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout) self.client: ConnectionIO(handler=None, reader=reader, writer=writer)
}
# Ask for the first layer right away. # Ask for the first layer right away.
# In a reverse proxy scenario, this is necessary as we would otherwise hang # In a reverse proxy scenario, this is necessary as we would otherwise hang
# on protocols that start with a server greeting. # on protocols that start with a server greeting.
self.layer.ask_now() self.layer = layer.NextLayer(self.context, ask_on_start=True)
self.transports = { self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
self.client: StreamIO(reader, writer)
}
async def handle_client(self) -> None: async def handle_client(self) -> None:
# FIXME: Work around log suppression in core. # Hack: Work around log suppression in core.
logging.getLogger('asyncio').setLevel(logging.DEBUG) logging.getLogger('asyncio').setLevel(logging.DEBUG)
asyncio.get_event_loop().set_debug(True)
watch = asyncio.ensure_future(self.timeout_watchdog.watch()) watch = asyncio.ensure_future(self.timeout_watchdog.watch())
self.log("[sans-io] clientconnect") self.log("[sans-io] clientconnect")
handler = asyncio.create_task(
self.handle_connection(self.client)
)
self.transports[self.client].handler = handler
self.server_event(events.Start()) self.server_event(events.Start())
await self.handle_connection(self.client) await handler
self.log("[sans-io] clientdisconnected") self.log("[sans-io] clientdisconnected")
watch.cancel() watch.cancel()
if self.transports: if self.transports:
self.log("[sans-io] closing transports...") self.log("[sans-io] closing transports...")
await asyncio.wait([ for x in self.transports.values():
self.close_connection(x) x.handler.cancel()
for x in self.transports await asyncio.wait([x.handler for x in self.transports.values()])
])
self.log("[sans-io] transports closed!") self.log("[sans-io] transports closed!")
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
await self.close_connection(self.client)
async def close_connection(self, connection: Connection) -> None:
self.log(f"closing {connection}", "debug")
connection.state = ConnectionState.CLOSED
io = self.transports.pop(connection)
io.w.close()
await io.w.wait_closed()
async def shutdown_connection(self, connection: Connection) -> None:
assert connection.state & ConnectionState.CAN_WRITE
io = self.transports[connection]
self.log(f"shutting down {connection}", "debug")
io.w.write_eof()
connection.state &= ~ConnectionState.CAN_WRITE
async def handle_connection(self, connection: Connection) -> None:
reader, writer = self.transports[connection]
while True:
try:
data = await reader.read(65535)
except socket.error:
data = b""
if data:
self.server_event(events.DataReceived(connection, data))
else:
if connection.state is ConnectionState.CAN_READ:
await self.close_connection(connection)
else:
connection.state &= ~ConnectionState.CAN_READ
self.server_event(events.ConnectionClosed(connection))
break
async def open_connection(self, command: commands.OpenConnection) -> None: async def open_connection(self, command: commands.OpenConnection) -> None:
if not command.connection.address: if not command.connection.address:
raise ValueError("Cannot open connection, no hostname given.") raise ValueError("Cannot open connection, no hostname given.")
assert command.connection not in self.transports
try: try:
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(*command.connection.address)
*command.connection.address except (IOError, asyncio.CancelledError) as e:
)
except IOError as e:
self.server_event(events.OpenConnectionReply(command, str(e))) self.server_event(events.OpenConnectionReply(command, str(e)))
else: else:
self.log("serverconnect") self.log(f"serverconnect {command.connection.address}")
self.transports[command.connection] = StreamIO(reader, writer) self.transports[command.connection].reader = reader
self.transports[command.connection].writer = writer
command.connection.state = ConnectionState.OPEN command.connection.state = ConnectionState.OPEN
self.server_event(events.OpenConnectionReply(command, None)) self.server_event(events.OpenConnectionReply(command, None))
try:
await self.handle_connection(command.connection) await self.handle_connection(command.connection)
finally:
self.log("serverdisconnected") self.log("serverdisconnected")
async def handle_connection(self, connection: Connection) -> None:
reader = self.transports[connection].reader
assert reader
try:
while True:
try:
data = await reader.read(65535)
except (socket.error, asyncio.CancelledError):
data = b""
if data:
self.server_event(events.DataReceived(connection, data))
else:
connection.state &= ~ConnectionState.CAN_READ
self.server_event(events.ConnectionClosed(connection))
if connection.state is ConnectionState.CLOSED:
self.transports[connection].handler.cancel()
await asyncio.Event().wait() # wait for cancellation
except asyncio.CancelledError:
connection.state = ConnectionState.CLOSED
io = self.transports.pop(connection)
io.writer.close()
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
self.transports[self.client].handler.cancel()
@abc.abstractmethod @abc.abstractmethod
async def handle_hook(self, hook: commands.Hook) -> None: async def handle_hook(self, hook: commands.Hook) -> None:
pass pass
@ -180,31 +176,24 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
try: try:
layer_commands = self.layer.handle_event(event) layer_commands = self.layer.handle_event(event)
for command in layer_commands: for command in layer_commands:
if isinstance(command, commands.OpenConnection): if isinstance(command, commands.OpenConnection):
asyncio.ensure_future( assert command.connection not in self.transports
handler = asyncio.create_task(
self.open_connection(command) self.open_connection(command)
) )
self.transports[command.connection] = ConnectionIO(handler=handler)
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
return # The connection has already been closed.
elif isinstance(command, commands.SendData): elif isinstance(command, commands.SendData):
try: self.transports[command.connection].writer.write(command.data)
io = self.transports[command.connection]
except KeyError:
raise RuntimeError(f"Cannot write to closed connection: {command.connection}")
else:
io.w.write(command.data)
elif isinstance(command, commands.CloseConnection): elif isinstance(command, commands.CloseConnection):
if command.connection == self.client: self.close_our_end(command.connection)
asyncio.ensure_future(
self.close_connection(command.connection)
)
else:
asyncio.ensure_future(
self.shutdown_connection(command.connection)
)
elif isinstance(command, commands.GetSocket): elif isinstance(command, commands.GetSocket):
socket = self.transports[command.connection].w.get_extra_info("socket") socket = self.transports[command.connection].writer.get_extra_info("socket")
self.server_event(events.GetSocketReply(command, socket)) self.server_event(events.GetSocketReply(command, socket))
elif isinstance(command, commands.Hook): elif isinstance(command, commands.Hook):
asyncio.ensure_future( asyncio.create_task(
self.handle_hook(command) self.handle_hook(command)
) )
elif isinstance(command, commands.Log): elif isinstance(command, commands.Log):
@ -214,6 +203,22 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
except Exception: except Exception:
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error") self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
def close_our_end(self, connection):
assert connection.state & ConnectionState.CAN_WRITE
self.log(f"shutting down {connection}", "debug")
try:
self.transports[connection].writer.write_eof()
except socket.error:
connection.state = ConnectionState.CLOSED
connection.state &= ~ConnectionState.CAN_WRITE
# if we are closing the client connection, we should destroy everything.
if connection == self.client:
self.transports[connection].handler.cancel()
# If we have already received a close, let's finish everything.
elif connection.state is ConnectionState.CLOSED:
self.transports[connection].handler.cancel()
class SimpleConnectionHandler(ConnectionHandler): class SimpleConnectionHandler(ConnectionHandler):
"""Simple handler that does not really process any hooks.""" """Simple handler that does not really process any hooks."""
@ -253,10 +258,10 @@ if __name__ == "__main__":
async def handle(reader, writer): async def handle(reader, writer):
layer_stack = [ layer_stack = [
lambda ctx: layers.ServerTLSLayer(ctx), lambda ctx: layers.ServerTLSLayer(ctx),
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular), lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular),
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx), lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
lambda ctx: layers.ClientTLSLayer(ctx), lambda ctx: layers.ClientTLSLayer(ctx),
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.transparent) lambda ctx: layers.HttpLayer(ctx, HTTPMode.transparent)
] ]
def next_layer(nl: layer.NextLayer): def next_layer(nl: layer.NextLayer):

View File

@ -6,13 +6,15 @@ import pytest
from mitmproxy.net.websockets import Frame, OPCODE from mitmproxy.net.websockets import Frame, OPCODE
from mitmproxy.proxy2 import commands, events from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.layers.old import websocket from mitmproxy.proxy2.layers.old import websocket
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.test import tflow from mitmproxy.test import tflow
from .. import tutils from .. import tutils
@pytest.fixture @pytest.fixture
def ws_playbook(tctx): def ws_playbook(tctx):
tctx.server.connected = True tctx.server.state = ConnectionState.OPEN
playbook = tutils.Playbook( playbook = tutils.Playbook(
websocket.WebsocketLayer( websocket.WebsocketLayer(
tctx, tctx,

View File

@ -6,7 +6,7 @@ from mitmproxy.proxy2 import layer
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import http, tls from mitmproxy.proxy2.layers import http, tls
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_next_layer
def test_http_proxy(tctx): def test_http_proxy(tctx):
@ -14,7 +14,7 @@ def test_http_proxy(tctx):
server = Placeholder() server = Placeholder()
flow = Placeholder() flow = Placeholder()
assert ( assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular)) Playbook(http.HttpLayer(tctx, HTTPMode.regular))
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHeadersHook(flow) << http.HttpRequestHeadersHook(flow)
>> reply() >> reply()
@ -39,7 +39,7 @@ def test_https_proxy(strategy, tctx):
"""Test a CONNECT request, followed by a HTTP GET /""" """Test a CONNECT request, followed by a HTTP GET /"""
server = Placeholder() server = Placeholder()
flow = Placeholder() flow = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular)) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
tctx.options.connection_strategy = strategy tctx.options.connection_strategy = strategy
(playbook (playbook
@ -54,7 +54,7 @@ def test_https_proxy(strategy, tctx):
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n') << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< layer.NextLayerHook(Placeholder()) << layer.NextLayerHook(Placeholder())
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent)) >> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
<< http.HttpRequestHeadersHook(flow) << http.HttpRequestHeadersHook(flow)
>> reply() >> reply()
<< http.HttpRequestHook(flow) << http.HttpRequestHook(flow)
@ -77,12 +77,15 @@ def test_https_proxy(strategy, tctx):
@pytest.mark.parametrize("https_client", [False, True]) @pytest.mark.parametrize("https_client", [False, True])
@pytest.mark.parametrize("https_server", [False, True]) @pytest.mark.parametrize("https_server", [False, True])
@pytest.mark.parametrize("strategy", ["lazy", "eager"]) @pytest.mark.parametrize("strategy", ["lazy", "eager"])
def test_redirect(strategy, https_server, https_client, tctx): def test_redirect(strategy, https_server, https_client, tctx, monkeypatch):
"""Test redirects between http:// and https:// in regular proxy mode.""" """Test redirects between http:// and https:// in regular proxy mode."""
server = Placeholder() server = Placeholder()
flow = Placeholder() flow = Placeholder()
tctx.options.connection_strategy = strategy tctx.options.connection_strategy = strategy
p = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) p = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
if https_server:
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
def redirect(flow: HTTPFlow): def redirect(flow: HTTPFlow):
if https_server: if https_server:
@ -98,16 +101,13 @@ def test_redirect(strategy, https_server, https_client, tctx):
p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n') p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
p << layer.NextLayerHook(Placeholder()) p << layer.NextLayerHook(Placeholder())
p >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent)) p >> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
else: else:
p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
p << http.HttpRequestHook(flow) p << http.HttpRequestHook(flow)
p >> reply(side_effect=redirect) p >> reply(side_effect=redirect)
p << OpenConnection(server) p << OpenConnection(server)
p >> reply(None) p >> reply(None)
if https_server:
pass # p << tls.EstablishServerTLS(server)
# p >> reply_establish_server_tls()
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n") p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
@ -123,7 +123,7 @@ def test_multiple_server_connections(tctx):
"""Test multiple requests being rewritten to different targets.""" """Test multiple requests being rewritten to different targets."""
server1 = Placeholder() server1 = Placeholder()
server2 = Placeholder() server2 = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
def redirect(to: str): def redirect(to: str):
def side_effect(flow: HTTPFlow): def side_effect(flow: HTTPFlow):
@ -164,7 +164,7 @@ def test_http_reply_from_proxy(tctx):
flow.response = HTTPResponse.make(418) flow.response = HTTPResponse.make(418)
assert ( assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHook(Placeholder()) << http.HttpRequestHook(Placeholder())
>> reply(side_effect=reply_from_proxy) >> reply(side_effect=reply_from_proxy)
@ -176,7 +176,7 @@ def test_response_until_eof(tctx):
"""Test scenario where the server response body is terminated by EOF.""" """Test scenario where the server response body is terminated by EOF."""
server = Placeholder() server = Placeholder()
assert ( assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< OpenConnection(server) << OpenConnection(server)
>> reply(None) >> reply(None)
@ -197,7 +197,7 @@ def test_disconnect_while_intercept(tctx):
flow = Placeholder() flow = Placeholder()
assert ( assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n") >> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
<< http.HttpConnectHook(Placeholder()) << http.HttpConnectHook(Placeholder())
>> reply() >> reply()
@ -206,7 +206,7 @@ def test_disconnect_while_intercept(tctx):
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n') << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< layer.NextLayerHook(Placeholder()) << layer.NextLayerHook(Placeholder())
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent)) >> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
<< http.HttpRequestHook(flow) << http.HttpRequestHook(flow)
>> ConnectionClosed(server1) >> ConnectionClosed(server1)
>> reply(to=-2) >> reply(to=-2)
@ -229,7 +229,7 @@ def test_response_streaming(tctx):
flow.response.stream = lambda x: x.upper() flow.response.stream = lambda x: x.upper()
assert ( assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< OpenConnection(server) << OpenConnection(server)
>> reply(None) >> reply(None)
@ -252,7 +252,7 @@ def test_request_streaming(tctx, response):
""" """
server = Placeholder() server = Placeholder()
flow = Placeholder() flow = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
def enable_streaming(flow: HTTPFlow): def enable_streaming(flow: HTTPFlow):
flow.request.stream = lambda x: x.upper() flow.request.stream = lambda x: x.upper()
@ -324,7 +324,7 @@ def test_server_aborts(tctx, data):
server = Placeholder() server = Placeholder()
flow = Placeholder() flow = Placeholder()
err = Placeholder() err = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
assert ( assert (
playbook playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")

View File

@ -1,4 +1,5 @@
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import tcp from mitmproxy.proxy2.layers import tcp
from ..tutils import Placeholder, Playbook, reply from ..tutils import Placeholder, Playbook, reply
@ -14,7 +15,7 @@ def test_open_connection(tctx):
<< OpenConnection(tctx.server) << OpenConnection(tctx.server)
) )
tctx.server.connected = True tctx.server.state = ConnectionState.OPEN
assert ( assert (
Playbook(tcp.TCPLayer(tctx, True)) Playbook(tcp.TCPLayer(tctx, True))
<< None << None

View File

@ -5,6 +5,7 @@ import pytest
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy.proxy2 import commands, context, events, layer from mitmproxy.proxy2 import commands, context, events, layer
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.layers import tls from mitmproxy.proxy2.layers import tls
from mitmproxy.utils import data from mitmproxy.utils import data
from test.mitmproxy.proxy2 import tutils from test.mitmproxy.proxy2 import tutils
@ -109,7 +110,6 @@ class TlsEchoLayer(tutils.EchoLayer):
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.data == b"open-connection": if isinstance(event, events.DataReceived) and event.data == b"open-connection":
# noinspection PyTypeChecker
err = yield commands.OpenConnection(self.context.server) err = yield commands.OpenConnection(self.context.server)
if err: if err:
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode()) yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
@ -180,23 +180,20 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
class TestServerTLS: class TestServerTLS:
def test_no_tls(self, tctx: context.Context): def test_not_connected(self, tctx: context.Context):
"""Test TLS layer without TLS""" """Test that we don't do anything if no server connection exists."""
layer = tls.ServerTLSLayer(tctx) layer = tls.ServerTLSLayer(tctx)
layer.child_layer = TlsEchoLayer(tctx) layer.child_layer = TlsEchoLayer(tctx)
# Handshake
assert ( assert (
tutils.Playbook(layer) tutils.Playbook(layer)
>> events.DataReceived(tctx.client, b"Hello World") >> events.DataReceived(tctx.client, b"Hello World")
<< commands.SendData(tctx.client, b"hello world") << commands.SendData(tctx.client, b"hello world")
>> events.DataReceived(tctx.server, b"Foo")
<< commands.SendData(tctx.server, b"foo")
) )
def test_simple(self, tctx): def test_simple(self, tctx):
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx)) playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.connected = True tctx.server.state = ConnectionState.OPEN
tctx.server.address = ("example.mitmproxy.org", 443) tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.sni = b"example.mitmproxy.org" tctx.server.sni = b"example.mitmproxy.org"
@ -250,7 +247,6 @@ class TestServerTLS:
def test_untrusted_cert(self, tctx): def test_untrusted_cert(self, tctx):
"""If the certificate is not trusted, we should fail.""" """If the certificate is not trusted, we should fail."""
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx)) playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.connected = True
tctx.server.address = ("wrong.host.mitmproxy.org", 443) tctx.server.address = ("wrong.host.mitmproxy.org", 443)
tctx.server.sni = b"wrong.host.mitmproxy.org" tctx.server.sni = b"wrong.host.mitmproxy.org"
@ -260,9 +256,11 @@ class TestServerTLS:
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, b"establish-server-tls") >> events.DataReceived(tctx.client, b"open-connection")
<< layer.NextLayerHook(tutils.Placeholder()) << layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer) >> tutils.reply_next_layer(TlsEchoLayer)
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
<< tls.TlsStartHook(tutils.Placeholder()) << tls.TlsStartHook(tutils.Placeholder())
>> reply_tls_start() >> reply_tls_start()
<< commands.SendData(tctx.server, data) << commands.SendData(tctx.server, data)
@ -278,7 +276,8 @@ class TestServerTLS:
>> events.DataReceived(tctx.server, tssl.out.read()) >> events.DataReceived(tctx.server, tssl.out.read())
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn") << commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
<< commands.CloseConnection(tctx.server) << commands.CloseConnection(tctx.server)
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch") << commands.SendData(tctx.client,
b"open-connection failed: Certificate verify failed: Hostname mismatch")
) )
assert not tctx.server.tls_established assert not tctx.server.tls_established
@ -334,10 +333,11 @@ class TestClientTLS:
# Echo # Echo
_test_echo(playbook, tssl_client, tctx.client) _test_echo(playbook, tssl_client, tctx.client)
other_server = context.Server(None)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.server, b"Plaintext") >> events.DataReceived(other_server, b"Plaintext")
<< commands.SendData(tctx.server, b"plaintext") << commands.SendData(other_server, b"plaintext")
) )
def test_server_required(self, tctx): def test_server_required(self, tctx):
@ -419,6 +419,7 @@ class TestClientTLS:
def test_mitmproxy_ca_is_untrusted(self, tctx: context.Context): def test_mitmproxy_ca_is_untrusted(self, tctx: context.Context):
"""Test the scenario where the client doesn't trust the mitmproxy CA.""" """Test the scenario where the client doesn't trust the mitmproxy CA."""
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org") playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org")
playbook.logs = True
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
@ -440,6 +441,7 @@ class TestClientTLS:
<< commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate " << commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate "
"for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn") "for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn")
<< commands.CloseConnection(tctx.client) << commands.CloseConnection(tctx.client)
>> events.ConnectionClosed(tctx.client)
) )
assert not tctx.client.tls_established assert not tctx.client.tls_established

View File

@ -1,4 +1,5 @@
import typing import typing
from dataclasses import dataclass
import pytest import pytest
@ -20,6 +21,7 @@ class TCommand(commands.Command):
self.x = x self.x = x
@dataclass
class TCommandReply(events.CommandReply): class TCommandReply(events.CommandReply):
command: TCommand command: TCommand
@ -157,7 +159,7 @@ def test_command_reply(tplaybook):
tplaybook tplaybook
>> TEvent() >> TEvent()
<< TCommand() << TCommand()
>> tutils.reply(42) >> tutils.reply()
) )
assert tplaybook.actual[1] == tplaybook.actual[2].command assert tplaybook.actual[1] == tplaybook.actual[2].command

View File

@ -10,8 +10,7 @@ from mitmproxy.proxy2 import commands, context, layer
from mitmproxy.proxy2 import events from mitmproxy.proxy2 import events
from mitmproxy.proxy2.context import ConnectionState from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.events import command_reply_subclasses from mitmproxy.proxy2.events import command_reply_subclasses
from mitmproxy.proxy2.layer import Layer, NextLayer from mitmproxy.proxy2.layer import Layer
from mitmproxy.proxy2.layers import tls
PlaybookEntry = typing.Union[commands.Command, events.Event] PlaybookEntry = typing.Union[commands.Command, events.Event]
PlaybookEntryList = typing.List[PlaybookEntry] PlaybookEntryList = typing.List[PlaybookEntry]
@ -101,7 +100,7 @@ class Playbook:
assert playbook(tcp.TCPLayer(tctx)) \ assert playbook(tcp.TCPLayer(tctx)) \
<< commands.OpenConnection(tctx.server) << commands.OpenConnection(tctx.server)
>> events.OpenConnectionReply(-1, "ok") # -1 = reply to command in previous line. >> reply(None)
<< None # this line is optional. << None # this line is optional.
This is syntactic sugar for the following: This is syntactic sugar for the following:
@ -351,15 +350,3 @@ def reply_next_layer(
next_layer.layer = child_layer(next_layer.context) next_layer.layer = child_layer(next_layer.context)
return reply(*args, side_effect=set_layer, **kwargs) return reply(*args, side_effect=set_layer, **kwargs)
def reply_establish_server_tls(**kwargs) -> reply:
"""Helper function to simplify the syntax for EstablishServerTls events to this:
<< tls.EstablishServerTLS(server)
>> tutils.reply_establish_server_tls()
"""
def fake_tls(cmd: tls.EstablishServerTLS) -> None:
cmd.connection.tls_established = True
return reply(None, side_effect=fake_tls, **kwargs)

View File

@ -84,7 +84,7 @@ class TestConnectionHandler:
def ask(_, x): def ask(_, x):
raise RuntimeError raise RuntimeError
channel.ask = ask channel._ask = ask
c = ConnectionHandler( c = ConnectionHandler(
mock.MagicMock(), mock.MagicMock(),
("127.0.0.1", 8080), ("127.0.0.1", 8080),