[sans-io] fixes, fixes, fixes

This commit is contained in:
Maximilian Hils 2020-01-02 03:09:48 +01:00
parent b2060356b6
commit a30a6758f3
20 changed files with 236 additions and 215 deletions

View File

@ -86,12 +86,11 @@ class NextLayer:
def next_layer(self, nextlayer: layer.NextLayer):
nextlayer.layer = self._next_layer(nextlayer.context, nextlayer.data_client())
# nextlayer.layer.debug = " " * len(nextlayer.context.layers)
def _next_layer(self, context: context.Context, data_client: bytes) -> typing.Optional[layer.Layer]:
if len(context.layers) == 0:
return self.make_top_layer(context)
if len(context.layers) == 1:
return layers.ServerTLSLayer(context)
if len(data_client) < 3:
return
@ -113,21 +112,14 @@ class NextLayer:
if isinstance(top_layer, layers.ServerTLSLayer):
return layers.ClientTLSLayer(context)
else:
if s(modes.HttpProxy):
# A "Secure Web Proxy" (https://www.chromium.org/developers/design-documents/secure-web-proxy)
# This does not imply TLS on the server side.
pass
else:
# In all other cases, client TLS implies TLS for both ends.
context.server.tls = True
return layers.ServerTLSLayer(context)
# 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy.
if any([
s(modes.HttpProxy, layers.ServerTLSLayer),
s(modes.HttpProxy, layers.ServerTLSLayer, layers.ClientTLSLayer),
s(modes.HttpProxy),
s(modes.HttpProxy, layers.ClientTLSLayer),
]):
return layers.HTTPLayer(context, HTTPMode.regular)
return layers.HttpLayer(context, HTTPMode.regular)
if ctx.options.mode.startswith("upstream:") and len(context.layers) <= 3 and isinstance(top_layer,
layers.ServerTLSLayer):
raise NotImplementedError()
@ -153,7 +145,7 @@ class NextLayer:
return layers.TCPLayer(context)
# 6. Assume HTTP by default.
return layers.HTTPLayer(context, HTTPMode.transparent)
return layers.HttpLayer(context, HTTPMode.transparent)
def make_top_layer(self, context: context.Context) -> layer.Layer:
if ctx.options.mode == "regular":

View File

@ -13,9 +13,11 @@ from mitmproxy.proxy2.layers import tls
def alpn_select_callback(conn: SSL.Connection, options):
server_alpn = conn.get_app_data()["server_alpn"]
http2 = conn.get_app_data()["http2"]
if server_alpn and server_alpn in options:
return server_alpn
for alpn in tls.HTTP_ALPNS:
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
for alpn in http_alpns:
if alpn in options:
return alpn
else:
@ -76,11 +78,11 @@ class TlsConfig:
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_app_data({
"server_alpn": tls_start.context.server.alpn
"server_alpn": tls_start.context.server.alpn,
"http2": ctx.options.http2,
})
tls_start.ssl_conn.set_accept_state()
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None:
client = tls_start.context.client
server: context.Server = tls_start.conn
@ -120,7 +122,7 @@ class TlsConfig:
if os.path.exists(path):
client_cert = path
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
args["cipher_list"] = b':'.join(server.cipher_list) if server.cipher_list else None
ssl_ctx = net_tls.create_client_context(
cert=client_cert,
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.

View File

@ -1,5 +1,5 @@
from enum import Flag, auto
from typing import List, Optional, Sequence, Union
from typing import List, Literal, Optional, Sequence, Union
from mitmproxy import certs
from mitmproxy.options import Options
@ -16,7 +16,7 @@ class Connection:
"""
Connections exposed to the layers only contain metadata, no socket objects.
"""
address: tuple
address: Optional[tuple]
state: ConnectionState
tls: bool = False
tls_established: bool = False
@ -25,7 +25,7 @@ class Connection:
alpn_offers: Sequence[bytes] = ()
cipher_list: Sequence[bytes] = ()
tls_version: Optional[str] = None
sni: Union[bytes, bool, None]
sni: Union[bytes, Literal[True], None]
timestamp_tls_setup: Optional[float] = None
@ -33,21 +33,17 @@ class Connection:
def connected(self):
return self.state is ConnectionState.OPEN
@connected.setter
def connected(self, val: bool) -> None:
# We should really set .state, but verdict is still due if we even want to keep .state around.
# We allow setting .connected while we figure that out.
if val:
self.state = ConnectionState.OPEN
else:
self.state = ConnectionState.CLOSED
def __repr__(self):
return f"{type(self).__name__}({repr(self.__dict__)})"
attrs = repr({
k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k,lambda: v)()
for k, v in self.__dict__.items()
})
return f"{type(self).__name__}({attrs})"
class Client(Connection):
sni: Optional[bytes] = None
sni: Union[bytes, None] = None
address: tuple
def __init__(self, address):
self.address = address
@ -55,9 +51,9 @@ class Client(Connection):
class Server(Connection):
sni: Union[bytes, bool] = True
sni = True
"""True: client SNI, False: no SNI, bytes: custom value"""
address: Optional[tuple]
via: Optional["Server"] = None
def __init__(self, address: Optional[tuple]):
self.address = address

View File

@ -5,7 +5,7 @@ The counterpart to events are commands.
"""
import socket
import typing
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from mitmproxy.proxy2 import commands
from mitmproxy.proxy2.context import Connection
@ -55,7 +55,6 @@ class DataReceived(ConnectionEvent):
return f"DataReceived({target}, {self.data})"
@dataclass
class CommandReply(Event):
"""
Emitted when a command has been finished, e.g.
@ -65,6 +64,7 @@ class CommandReply(Event):
reply: typing.Any
def __new__(cls, *args, **kwargs):
assert is_dataclass(cls)
if cls is CommandReply:
raise TypeError("CommandReply may not be instantiated directly.")
return super().__new__(cls)

View File

@ -5,6 +5,7 @@ import collections
import textwrap
import typing
from abc import abstractmethod
from dataclasses import dataclass
from mitmproxy import log
from mitmproxy.proxy2 import commands, events
@ -149,11 +150,14 @@ class NextLayer(Layer):
events: typing.List[mevents.Event]
"""All events that happened before a decision was made."""
def __init__(self, context: Context) -> None:
_ask_on_start: bool
def __init__(self, context: Context, ask_on_start: bool = False) -> None:
super().__init__(context)
self.context.layers.remove(self)
self.events = []
self.layer = None
self.events = []
self._ask_on_start = ask_on_start
self._handle = None
def __repr__(self):
@ -169,11 +173,13 @@ class NextLayer(Layer):
self.events.append(event)
# We receive new data. Let's find out if we can determine the next layer now?
if isinstance(event, mevents.DataReceived):
if self._ask_on_start and isinstance(event, events.Start):
yield from self._ask()
elif isinstance(event, mevents.DataReceived):
# For now, we only ask if we have received new data to reduce hook noise.
yield from self.ask_now()
yield from self._ask()
def ask_now(self):
def _ask(self):
"""
Manually trigger a next_layer hook.
The only use at the moment is to make sure that the top layer is initialized.

View File

@ -1,11 +1,11 @@
from . import modes
from .http import HTTPLayer
from .http import HttpLayer
from .tcp import TCPLayer
from .tls import ClientTLSLayer, ServerTLSLayer
__all__ = [
"modes",
"HTTPLayer",
"HttpLayer",
"TCPLayer",
"ClientTLSLayer", "ServerTLSLayer",
]

View File

@ -1,5 +1,6 @@
import collections
import typing
from dataclasses import dataclass
from mitmproxy import flow, http
from mitmproxy.proxy.protocol.http import HTTPMode
@ -8,7 +9,7 @@ from mitmproxy.proxy2.context import Connection, Context, Server
from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
@ -37,6 +38,7 @@ class GetHttpConnection(HttpCommand):
)
@dataclass
class GetHttpConnectionReply(events.CommandReply):
command: GetHttpConnection
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
@ -67,8 +69,6 @@ class SendHttp(HttpCommand):
return f"Send({self.event})"
class HttpStream(layer.Layer):
request_body_buf: bytes
response_body_buf: bytes
@ -78,7 +78,7 @@ class HttpStream(layer.Layer):
@property
def mode(self):
parent: HTTPLayer = self.context.layers[-2]
parent: HttpLayer = self.context.layers[-2]
return parent.mode
def __init__(self, context: Context):
@ -309,7 +309,7 @@ class HttpStream(layer.Layer):
yield from ()
class HTTPLayer(layer.Layer):
class HttpLayer(layer.Layer):
"""
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
HttpEvent: We have received request headers
@ -339,22 +339,22 @@ class HTTPLayer(layer.Layer):
}
def __repr__(self):
return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
return f"HttpLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
def _handle_event(self, event: events.Event):
if isinstance(event, events.Start):
return
elif isinstance(event, events.CommandReply):
try:
stream = self.stream_by_command.pop(event.command)
except KeyError:
raise
self.event_to_child(stream, event)
elif isinstance(event, events.ConnectionEvent):
if event.connection == self.context.server and self.context.server not in self.connections:
pass
else:
try:
handler = self.connections[event.connection]
except KeyError:
raise
self.event_to_child(handler, event)
else:
raise ValueError(f"Unexpected event: {event}")
@ -388,7 +388,7 @@ class HTTPLayer(layer.Layer):
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
# Do we already have a connection we can re-use?
for connection, layer in self.connections.items():
for connection in self.connections:
connection_suitable = (
reuse and
event.connection_spec_matches(connection) and
@ -402,7 +402,7 @@ class HTTPLayer(layer.Layer):
self.waiting_for_establishment[connection].append(event)
else:
stream = self.stream_by_command.pop(event)
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None)))
self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None)))
return
can_reuse_context_connection = (
@ -414,8 +414,11 @@ class HTTPLayer(layer.Layer):
layer = HttpClient(context)
if not can_reuse_context_connection:
context.server = Server(event.address)
if context.options.http2:
context.server.alpn_offers = tls.HTTP_ALPNS
else:
context.server.alpn_offers = tls.HTTP1_ALPNS
if event.tls:
context.server.tls = True
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
orig = layer
layer = tls.ServerTLSLayer(context)
@ -432,7 +435,6 @@ class HTTPLayer(layer.Layer):
if command.err:
reply = (None, command.err)
self.connections.pop(command.connection)
else:
reply = (command.connection, None)

View File

@ -1,3 +1,5 @@
from dataclasses import dataclass
from mitmproxy import http
from mitmproxy.proxy2 import commands

View File

@ -12,6 +12,7 @@ from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Client, Connection, Context, Server
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
from ._base import HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
@ -131,7 +132,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, ResponseEndOfMessage):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request, self.response) == -1:
elif http1.expected_http_body_size(self.request, self.response) == -1 or self.request.first_line_format == "authority":
yield commands.CloseConnection(self.conn)
yield from self.mark_done(response=True)
elif isinstance(event, ResponseProtocolError):
@ -162,7 +163,13 @@ class Http1Server(Http1Connection):
request_head = self.buf.maybe_extract_lines()
if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
try:
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
except ValueError as e:
yield commands.Log(f"{human.format_address(self.conn.address)}: {e}")
yield commands.CloseConnection(self.conn)
self.state = self.wait
return
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
if self.request.first_line_format == "authority":
@ -181,6 +188,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, events.ConnectionClosed):
if bytes(self.buf).strip():
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
yield commands.Log(f"Receive Buffer: {bytes(self.buf)}", level="debug")
yield commands.CloseConnection(self.conn)
else:
raise ValueError(f"Unexpected event: {event}")

View File

@ -2,6 +2,7 @@ from mitmproxy import platform
from mitmproxy.net import server_spec
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Server
from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.utils import expect
@ -11,9 +12,10 @@ class ReverseProxy(layer.Layer):
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
self.context.server = Server(spec.address)
if spec.scheme not in ("http", "tcp"):
self.context.server.tls = True
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0]
child_layer = tls.ServerTLSLayer(self.context)
else:
child_layer = layer.NextLayer(self.context)
self._handle_event = child_layer.handle_event
yield from child_layer.handle_event(event)

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Optional
from mitmproxy import flow, tcp

View File

@ -90,7 +90,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
return None
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
# We need these classes as hooks can only have one argument at the moment.
@ -122,10 +123,16 @@ class _TLSLayer(layer.Layer):
tls: SSL.Connection = None
"""The OpenSSL connection object"""
child_layer: layer.Layer
errored: bool = False
"""Have we errored yet?"""
def __init__(self, context: context.Context, conn: context.Connection):
super().__init__(context)
assert not conn.tls
conn.tls = True
self.conn = conn
self.child_layer = layer.NextLayer(self.context)
def __repr__(self):
if not self.tls:
@ -138,8 +145,6 @@ class _TLSLayer(layer.Layer):
def start_tls(self, initial_data: bytes = b""):
assert not self.tls
assert self.conn.connected
self.conn.tls = True
tls_start = TlsStartData(self.conn, self.context)
yield TlsStartHook(tls_start)
@ -227,7 +232,7 @@ class _TLSLayer(layer.Layer):
)
if close:
self.conn.state &= ~context.ConnectionState.CAN_READ
yield commands.Log(f"TLS close_notify {self.conn}")
yield commands.Log(f"TLS close_notify {self.conn}", level="debug")
yield from self.event_to_child(
events.ConnectionClosed(self.conn)
)
@ -252,12 +257,13 @@ class _TLSLayer(layer.Layer):
pass # We have already dispatched a ConnectionClosed to the child layer.
else:
yield from self.event_to_child(event)
else:
elif not self.errored:
yield from self.on_handshake_error("connection closed without notice")
else:
yield from self.event_to_child(event)
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
self.errored = True
yield commands.CloseConnection(self.conn)
@ -269,13 +275,12 @@ class ServerTLSLayer(_TLSLayer):
def __init__(self, context: context.Context):
super().__init__(context, context.server)
self.child_layer = layer.NextLayer(self.context)
@expect(events.Start)
def state_start(self, _) -> layer.CommandGenerator[None]:
self.context.server.tls = True
def state_start(self, event) -> layer.CommandGenerator[None]:
if self.context.server.connected:
yield from self.start_tls()
yield from self.event_to_child(event)
self._handle_event = super()._handle_event
_handle_event = state_start
@ -327,20 +332,12 @@ class ClientTLSLayer(_TLSLayer):
"""
recv_buffer: bytearray
server_tls_available: bool
def __init__(self, context: context.Context):
assert isinstance(context.layers[-1], ServerTLSLayer)
super().__init__(context, self.context.client)
super().__init__(context, context.client)
self.server_tls_available = isinstance(self.context.layers[-2], ServerTLSLayer)
self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context)
@expect(events.Start)
def state_start(self, _) -> layer.CommandGenerator[None]:
self.context.client.tls = True
self._handle_event = self.state_wait_for_clienthello
yield from ()
_handle_event = state_start
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.connection == self.conn:
@ -361,8 +358,8 @@ class ClientTLSLayer(_TLSLayer):
if tls_clienthello.establish_server_tls_first and not self.context.server.tls_established:
err = yield from self.start_server_tls()
if err:
yield commands.Log("Unable to establish TLS connection with server. "
"Trying to establish TLS with client anyway.")
yield commands.Log(f"Unable to establish TLS connection with server ({err}). "
f"Trying to establish TLS with client anyway.")
yield from self.start_tls(bytes(self.recv_buffer))
self.recv_buffer.clear()
@ -373,14 +370,17 @@ class ClientTLSLayer(_TLSLayer):
else:
yield from self.event_to_child(event)
_handle_event = state_wait_for_clienthello
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
"""
We often need information from the upstream connection to establish TLS with the client.
For example, we need to check if the client does ALPN or not.
"""
if not self.server_tls_available:
return "No server TLS available."
err = yield commands.OpenConnection(self.context.server)
if err:
yield commands.Log(f"Cannot establish server connection: {err}")
return err
else:
return None
@ -400,3 +400,14 @@ class ClientTLSLayer(_TLSLayer):
level="warn"
)
yield from super().on_handshake_error(err)
class MockTLSLayer(_TLSLayer):
"""Mock layer to disable actual TLS and use cleartext in tests.
Use like so:
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
"""
def __init__(self, ctx: context.Context):
super().__init__(ctx, context.Server(None))

View File

@ -15,6 +15,7 @@ import time
import traceback
import typing
from contextlib import contextmanager
from dataclasses import dataclass
from OpenSSL import SSL
@ -27,11 +28,6 @@ from mitmproxy.utils import human
from mitmproxy.utils.data import pkg_data
class StreamIO(typing.NamedTuple):
r: asyncio.StreamReader
w: asyncio.StreamWriter
class TimeoutWatchdog:
last_activity: float
CONNECTION_TIMEOUT = 10 * 60
@ -72,8 +68,15 @@ class TimeoutWatchdog:
self.can_timeout.set()
@dataclass
class ConnectionIO:
handler: typing.Optional[asyncio.Task] = None
reader: typing.Optional[asyncio.StreamReader] = None
writer: typing.Optional[asyncio.StreamWriter] = None
class ConnectionHandler(metaclass=abc.ABCMeta):
transports: typing.MutableMapping[Connection, StreamIO]
transports: typing.MutableMapping[Connection, ConnectionIO]
timeout_watchdog: TimeoutWatchdog
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
@ -81,93 +84,86 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.client = Client(addr)
self.context = Context(self.client, options)
self.layer = layer.NextLayer(self.context)
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
self.transports = {
self.client: ConnectionIO(handler=None, reader=reader, writer=writer)
}
# Ask for the first layer right away.
# In a reverse proxy scenario, this is necessary as we would otherwise hang
# on protocols that start with a server greeting.
self.layer.ask_now()
self.layer = layer.NextLayer(self.context, ask_on_start=True)
self.transports = {
self.client: StreamIO(reader, writer)
}
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
async def handle_client(self) -> None:
# FIXME: Work around log suppression in core.
# Hack: Work around log suppression in core.
logging.getLogger('asyncio').setLevel(logging.DEBUG)
asyncio.get_event_loop().set_debug(True)
watch = asyncio.ensure_future(self.timeout_watchdog.watch())
self.log("[sans-io] clientconnect")
handler = asyncio.create_task(
self.handle_connection(self.client)
)
self.transports[self.client].handler = handler
self.server_event(events.Start())
await self.handle_connection(self.client)
await handler
self.log("[sans-io] clientdisconnected")
watch.cancel()
if self.transports:
self.log("[sans-io] closing transports...")
await asyncio.wait([
self.close_connection(x)
for x in self.transports
])
for x in self.transports.values():
x.handler.cancel()
await asyncio.wait([x.handler for x in self.transports.values()])
self.log("[sans-io] transports closed!")
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
await self.close_connection(self.client)
async def close_connection(self, connection: Connection) -> None:
self.log(f"closing {connection}", "debug")
connection.state = ConnectionState.CLOSED
io = self.transports.pop(connection)
io.w.close()
await io.w.wait_closed()
async def shutdown_connection(self, connection: Connection) -> None:
assert connection.state & ConnectionState.CAN_WRITE
io = self.transports[connection]
self.log(f"shutting down {connection}", "debug")
io.w.write_eof()
connection.state &= ~ConnectionState.CAN_WRITE
async def handle_connection(self, connection: Connection) -> None:
reader, writer = self.transports[connection]
while True:
try:
data = await reader.read(65535)
except socket.error:
data = b""
if data:
self.server_event(events.DataReceived(connection, data))
else:
if connection.state is ConnectionState.CAN_READ:
await self.close_connection(connection)
else:
connection.state &= ~ConnectionState.CAN_READ
self.server_event(events.ConnectionClosed(connection))
break
async def open_connection(self, command: commands.OpenConnection) -> None:
if not command.connection.address:
raise ValueError("Cannot open connection, no hostname given.")
assert command.connection not in self.transports
try:
reader, writer = await asyncio.open_connection(
*command.connection.address
)
except IOError as e:
reader, writer = await asyncio.open_connection(*command.connection.address)
except (IOError, asyncio.CancelledError) as e:
self.server_event(events.OpenConnectionReply(command, str(e)))
else:
self.log("serverconnect")
self.transports[command.connection] = StreamIO(reader, writer)
self.log(f"serverconnect {command.connection.address}")
self.transports[command.connection].reader = reader
self.transports[command.connection].writer = writer
command.connection.state = ConnectionState.OPEN
self.server_event(events.OpenConnectionReply(command, None))
try:
await self.handle_connection(command.connection)
finally:
self.log("serverdisconnected")
async def handle_connection(self, connection: Connection) -> None:
reader = self.transports[connection].reader
assert reader
try:
while True:
try:
data = await reader.read(65535)
except (socket.error, asyncio.CancelledError):
data = b""
if data:
self.server_event(events.DataReceived(connection, data))
else:
connection.state &= ~ConnectionState.CAN_READ
self.server_event(events.ConnectionClosed(connection))
if connection.state is ConnectionState.CLOSED:
self.transports[connection].handler.cancel()
await asyncio.Event().wait() # wait for cancellation
except asyncio.CancelledError:
connection.state = ConnectionState.CLOSED
io = self.transports.pop(connection)
io.writer.close()
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
self.transports[self.client].handler.cancel()
@abc.abstractmethod
async def handle_hook(self, hook: commands.Hook) -> None:
pass
@ -180,31 +176,24 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
try:
layer_commands = self.layer.handle_event(event)
for command in layer_commands:
if isinstance(command, commands.OpenConnection):
asyncio.ensure_future(
assert command.connection not in self.transports
handler = asyncio.create_task(
self.open_connection(command)
)
self.transports[command.connection] = ConnectionIO(handler=handler)
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
return # The connection has already been closed.
elif isinstance(command, commands.SendData):
try:
io = self.transports[command.connection]
except KeyError:
raise RuntimeError(f"Cannot write to closed connection: {command.connection}")
else:
io.w.write(command.data)
self.transports[command.connection].writer.write(command.data)
elif isinstance(command, commands.CloseConnection):
if command.connection == self.client:
asyncio.ensure_future(
self.close_connection(command.connection)
)
else:
asyncio.ensure_future(
self.shutdown_connection(command.connection)
)
self.close_our_end(command.connection)
elif isinstance(command, commands.GetSocket):
socket = self.transports[command.connection].w.get_extra_info("socket")
socket = self.transports[command.connection].writer.get_extra_info("socket")
self.server_event(events.GetSocketReply(command, socket))
elif isinstance(command, commands.Hook):
asyncio.ensure_future(
asyncio.create_task(
self.handle_hook(command)
)
elif isinstance(command, commands.Log):
@ -214,6 +203,22 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
except Exception:
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
def close_our_end(self, connection):
assert connection.state & ConnectionState.CAN_WRITE
self.log(f"shutting down {connection}", "debug")
try:
self.transports[connection].writer.write_eof()
except socket.error:
connection.state = ConnectionState.CLOSED
connection.state &= ~ConnectionState.CAN_WRITE
# if we are closing the client connection, we should destroy everything.
if connection == self.client:
self.transports[connection].handler.cancel()
# If we have already received a close, let's finish everything.
elif connection.state is ConnectionState.CLOSED:
self.transports[connection].handler.cancel()
class SimpleConnectionHandler(ConnectionHandler):
"""Simple handler that does not really process any hooks."""
@ -253,10 +258,10 @@ if __name__ == "__main__":
async def handle(reader, writer):
layer_stack = [
lambda ctx: layers.ServerTLSLayer(ctx),
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular),
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
lambda ctx: layers.ClientTLSLayer(ctx),
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.transparent)
lambda ctx: layers.HttpLayer(ctx, HTTPMode.transparent)
]
def next_layer(nl: layer.NextLayer):

View File

@ -6,13 +6,15 @@ import pytest
from mitmproxy.net.websockets import Frame, OPCODE
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.layers.old import websocket
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.test import tflow
from .. import tutils
@pytest.fixture
def ws_playbook(tctx):
tctx.server.connected = True
tctx.server.state = ConnectionState.OPEN
playbook = tutils.Playbook(
websocket.WebsocketLayer(
tctx,

View File

@ -6,7 +6,7 @@ from mitmproxy.proxy2 import layer
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import http, tls
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_next_layer
def test_http_proxy(tctx):
@ -14,7 +14,7 @@ def test_http_proxy(tctx):
server = Placeholder()
flow = Placeholder()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
Playbook(http.HttpLayer(tctx, HTTPMode.regular))
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHeadersHook(flow)
>> reply()
@ -39,7 +39,7 @@ def test_https_proxy(strategy, tctx):
"""Test a CONNECT request, followed by a HTTP GET /"""
server = Placeholder()
flow = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
tctx.options.connection_strategy = strategy
(playbook
@ -54,7 +54,7 @@ def test_https_proxy(strategy, tctx):
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< layer.NextLayerHook(Placeholder())
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
>> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
@ -77,12 +77,15 @@ def test_https_proxy(strategy, tctx):
@pytest.mark.parametrize("https_client", [False, True])
@pytest.mark.parametrize("https_server", [False, True])
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
def test_redirect(strategy, https_server, https_client, tctx):
def test_redirect(strategy, https_server, https_client, tctx, monkeypatch):
"""Test redirects between http:// and https:// in regular proxy mode."""
server = Placeholder()
flow = Placeholder()
tctx.options.connection_strategy = strategy
p = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
p = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
if https_server:
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
def redirect(flow: HTTPFlow):
if https_server:
@ -98,16 +101,13 @@ def test_redirect(strategy, https_server, https_client, tctx):
p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
p << layer.NextLayerHook(Placeholder())
p >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
p >> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
else:
p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
p << http.HttpRequestHook(flow)
p >> reply(side_effect=redirect)
p << OpenConnection(server)
p >> reply(None)
if https_server:
pass # p << tls.EstablishServerTLS(server)
# p >> reply_establish_server_tls()
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
@ -123,7 +123,7 @@ def test_multiple_server_connections(tctx):
"""Test multiple requests being rewritten to different targets."""
server1 = Placeholder()
server2 = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
def redirect(to: str):
def side_effect(flow: HTTPFlow):
@ -164,7 +164,7 @@ def test_http_reply_from_proxy(tctx):
flow.response = HTTPResponse.make(418)
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHook(Placeholder())
>> reply(side_effect=reply_from_proxy)
@ -176,7 +176,7 @@ def test_response_until_eof(tctx):
"""Test scenario where the server response body is terminated by EOF."""
server = Placeholder()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< OpenConnection(server)
>> reply(None)
@ -197,7 +197,7 @@ def test_disconnect_while_intercept(tctx):
flow = Placeholder()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
<< http.HttpConnectHook(Placeholder())
>> reply()
@ -206,7 +206,7 @@ def test_disconnect_while_intercept(tctx):
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< layer.NextLayerHook(Placeholder())
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
>> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
<< http.HttpRequestHook(flow)
>> ConnectionClosed(server1)
>> reply(to=-2)
@ -229,7 +229,7 @@ def test_response_streaming(tctx):
flow.response.stream = lambda x: x.upper()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< OpenConnection(server)
>> reply(None)
@ -252,7 +252,7 @@ def test_request_streaming(tctx, response):
"""
server = Placeholder()
flow = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
def enable_streaming(flow: HTTPFlow):
flow.request.stream = lambda x: x.upper()
@ -324,7 +324,7 @@ def test_server_aborts(tctx, data):
server = Placeholder()
flow = Placeholder()
err = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
assert (
playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")

View File

@ -1,4 +1,5 @@
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import tcp
from ..tutils import Placeholder, Playbook, reply
@ -14,7 +15,7 @@ def test_open_connection(tctx):
<< OpenConnection(tctx.server)
)
tctx.server.connected = True
tctx.server.state = ConnectionState.OPEN
assert (
Playbook(tcp.TCPLayer(tctx, True))
<< None

View File

@ -5,6 +5,7 @@ import pytest
from OpenSSL import SSL
from mitmproxy.proxy2 import commands, context, events, layer
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.layers import tls
from mitmproxy.utils import data
from test.mitmproxy.proxy2 import tutils
@ -109,7 +110,6 @@ class TlsEchoLayer(tutils.EchoLayer):
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
# noinspection PyTypeChecker
err = yield commands.OpenConnection(self.context.server)
if err:
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
@ -180,23 +180,20 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
class TestServerTLS:
def test_no_tls(self, tctx: context.Context):
"""Test TLS layer without TLS"""
def test_not_connected(self, tctx: context.Context):
"""Test that we don't do anything if no server connection exists."""
layer = tls.ServerTLSLayer(tctx)
layer.child_layer = TlsEchoLayer(tctx)
# Handshake
assert (
tutils.Playbook(layer)
>> events.DataReceived(tctx.client, b"Hello World")
<< commands.SendData(tctx.client, b"hello world")
>> events.DataReceived(tctx.server, b"Foo")
<< commands.SendData(tctx.server, b"foo")
)
def test_simple(self, tctx):
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.connected = True
tctx.server.state = ConnectionState.OPEN
tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.sni = b"example.mitmproxy.org"
@ -250,7 +247,6 @@ class TestServerTLS:
def test_untrusted_cert(self, tctx):
"""If the certificate is not trusted, we should fail."""
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.connected = True
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
tctx.server.sni = b"wrong.host.mitmproxy.org"
@ -260,9 +256,11 @@ class TestServerTLS:
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.client, b"establish-server-tls")
>> events.DataReceived(tctx.client, b"open-connection")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer)
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
<< tls.TlsStartHook(tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.server, data)
@ -278,7 +276,8 @@ class TestServerTLS:
>> events.DataReceived(tctx.server, tssl.out.read())
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
<< commands.CloseConnection(tctx.server)
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
<< commands.SendData(tctx.client,
b"open-connection failed: Certificate verify failed: Hostname mismatch")
)
assert not tctx.server.tls_established
@ -334,10 +333,11 @@ class TestClientTLS:
# Echo
_test_echo(playbook, tssl_client, tctx.client)
other_server = context.Server(None)
assert (
playbook
>> events.DataReceived(tctx.server, b"Plaintext")
<< commands.SendData(tctx.server, b"plaintext")
>> events.DataReceived(other_server, b"Plaintext")
<< commands.SendData(other_server, b"plaintext")
)
def test_server_required(self, tctx):
@ -419,6 +419,7 @@ class TestClientTLS:
def test_mitmproxy_ca_is_untrusted(self, tctx: context.Context):
"""Test the scenario where the client doesn't trust the mitmproxy CA."""
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org")
playbook.logs = True
data = tutils.Placeholder()
assert (
@ -440,6 +441,7 @@ class TestClientTLS:
<< commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate "
"for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn")
<< commands.CloseConnection(tctx.client)
>> events.ConnectionClosed(tctx.client)
)
assert not tctx.client.tls_established

View File

@ -1,4 +1,5 @@
import typing
from dataclasses import dataclass
import pytest
@ -20,6 +21,7 @@ class TCommand(commands.Command):
self.x = x
@dataclass
class TCommandReply(events.CommandReply):
command: TCommand
@ -157,7 +159,7 @@ def test_command_reply(tplaybook):
tplaybook
>> TEvent()
<< TCommand()
>> tutils.reply(42)
>> tutils.reply()
)
assert tplaybook.actual[1] == tplaybook.actual[2].command

View File

@ -10,8 +10,7 @@ from mitmproxy.proxy2 import commands, context, layer
from mitmproxy.proxy2 import events
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.events import command_reply_subclasses
from mitmproxy.proxy2.layer import Layer, NextLayer
from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.layer import Layer
PlaybookEntry = typing.Union[commands.Command, events.Event]
PlaybookEntryList = typing.List[PlaybookEntry]
@ -101,7 +100,7 @@ class Playbook:
assert playbook(tcp.TCPLayer(tctx)) \
<< commands.OpenConnection(tctx.server)
>> events.OpenConnectionReply(-1, "ok") # -1 = reply to command in previous line.
>> reply(None)
<< None # this line is optional.
This is syntactic sugar for the following:
@ -351,15 +350,3 @@ def reply_next_layer(
next_layer.layer = child_layer(next_layer.context)
return reply(*args, side_effect=set_layer, **kwargs)
def reply_establish_server_tls(**kwargs) -> reply:
"""Helper function to simplify the syntax for EstablishServerTls events to this:
<< tls.EstablishServerTLS(server)
>> tutils.reply_establish_server_tls()
"""
def fake_tls(cmd: tls.EstablishServerTLS) -> None:
cmd.connection.tls_established = True
return reply(None, side_effect=fake_tls, **kwargs)

View File

@ -84,7 +84,7 @@ class TestConnectionHandler:
def ask(_, x):
raise RuntimeError
channel.ask = ask
channel._ask = ask
c = ConnectionHandler(
mock.MagicMock(),
("127.0.0.1", 8080),