[sans-io] wip: tls establishment semantics

This commit is contained in:
Maximilian Hils 2019-12-27 18:39:01 +01:00
parent 7efe27be74
commit b2060356b6
10 changed files with 294 additions and 286 deletions

View File

@ -74,16 +74,6 @@ class CloseConnection(ConnectionCommand):
"""
class ProtocolError(Command):
"""
Indicate that an unrecoverable protocol error has occured.
"""
message: str
def __init__(self, message: str):
self.message = message
class Hook(Command):
"""
Callback to the master (like ".ask()")
@ -143,5 +133,3 @@ class Log(Command):
def __repr__(self):
return f"Log({self.message}, {self.level})"

View File

@ -5,10 +5,10 @@ from mitmproxy import flow, http
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Connection, Context, Server
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
from ._base import HttpConnection, StreamId
from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
@ -17,10 +17,6 @@ from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client
class HttpCommand(commands.Command):
pass
class GetHttpConnection(HttpCommand):
"""
Open a HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead.
@ -47,6 +43,18 @@ class GetHttpConnectionReply(events.CommandReply):
"""connection object, error message"""
class RegisterHttpConnection(HttpCommand):
"""
Register that a HTTP connection has been successfully established.
"""
connection: Connection
err: str
def __init__(self, connection: Connection, err: str):
self.connection = connection
self.err = err
class SendHttp(HttpCommand):
connection: Connection
event: HttpEvent
@ -59,6 +67,8 @@ class SendHttp(HttpCommand):
return f"Send({self.event})"
class HttpStream(layer.Layer):
request_body_buf: bytes
response_body_buf: bytes
@ -299,7 +309,7 @@ class HttpStream(layer.Layer):
yield from ()
class HTTPLayer(Layer):
class HTTPLayer(layer.Layer):
"""
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
HttpEvent: We have received request headers
@ -311,47 +321,34 @@ class HTTPLayer(Layer):
mode: HTTPMode
stream_by_command: typing.Dict[commands.Command, HttpStream]
streams: typing.Dict[int, HttpStream]
connections: typing.Dict[Connection, typing.Union[HttpConnection, HttpStream]]
waiting_for_connection: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
event_queue: typing.Deque[
typing.Union[HttpEvent, HttpCommand, commands.Command]
]
connections: typing.Dict[Connection, typing.Union[layer.Layer, HttpStream]]
waiting_for_establishment: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
command_queue: typing.Deque[commands.Command]
def __init__(self, context: Context, mode: HTTPMode):
super().__init__(context)
self.mode = mode
self.waiting_for_connection = collections.defaultdict(list)
self.waiting_for_establishment = collections.defaultdict(list)
self.streams = {}
self.stream_by_command = {}
self.event_queue = collections.deque()
self.command_queue = collections.deque()
self.connections = {
context.client: Http1Server(context.client)
context.client: Http1Server(context.fork())
}
def __repr__(self):
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
def _handle_event(self, event: events.Event):
if isinstance(event, events.Start):
return
elif isinstance(event, (EstablishServerTLSReply, events.OpenConnectionReply)) and \
event.command.connection in self.waiting_for_connection:
if event.reply:
waiting = self.waiting_for_connection.pop(event.command.connection)
for cmd in waiting:
stream = self.stream_by_command.pop(cmd)
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
else:
yield from self.make_http_connection(event.command.connection)
elif isinstance(event, events.CommandReply):
try:
stream = self.stream_by_command.pop(event.command)
except KeyError:
raise
if isinstance(event, events.OpenConnectionReply):
self.connections[event.command.connection] = stream
self.event_to_child(stream, event)
elif isinstance(event, events.ConnectionEvent):
if event.connection == self.context.server and self.context.server not in self.connections:
@ -362,117 +359,129 @@ class HTTPLayer(Layer):
else:
raise ValueError(f"Unexpected event: {event}")
while self.event_queue:
event = self.event_queue.popleft()
if isinstance(event, RequestHeaders):
self.streams[event.stream_id] = self.make_stream()
if isinstance(event, HttpEvent):
stream = self.streams[event.stream_id]
self.event_to_child(stream, event)
elif isinstance(event, SendHttp):
conn = self.connections[event.connection]
evts = conn.send(event.event)
self.event_queue.extend(evts)
elif isinstance(event, GetHttpConnection):
yield from self.get_connection(event)
elif isinstance(event, commands.Command):
yield event
else:
raise ValueError(f"Unexpected event: {event}")
while self.command_queue:
command = self.command_queue.popleft()
if isinstance(command, ReceiveHttp):
if isinstance(command.event, RequestHeaders):
self.streams[command.event.stream_id] = self.make_stream()
stream = self.streams[command.event.stream_id]
self.event_to_child(stream, command.event)
elif isinstance(command, SendHttp):
conn = self.connections[command.connection]
self.event_to_child(conn, command.event)
elif isinstance(command, GetHttpConnection):
self.get_connection(command)
elif isinstance(command, RegisterHttpConnection):
yield from self.register_connection(command)
elif isinstance(command, commands.Command):
yield command
else: # pragma: no cover
raise ValueError(f"Not a command command: {command}")
def make_stream(self) -> HttpStream:
ctx = self.context.fork()
stream = HttpStream(ctx)
if self.debug:
stream.debug = self.debug + " "
self.event_to_child(stream, events.Start())
return stream
def get_connection(self, event: GetHttpConnection):
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
# Do we already have a connection we can re-use?
for connection, handler in self.connections.items():
for connection, layer in self.connections.items():
connection_suitable = (
reuse and
event.connection_spec_matches(connection) and
(
isinstance(handler, Http2Client) or
# see "tricky multiplexing edge case" in make_http_connection for an explanation
isinstance(handler, Http1Client) and self.context.client.alpn != b"h2"
connection.alpn == b"h2" or self.context.client.alpn != b"h2"
)
)
if connection_suitable:
if connection in self.waiting_for_establishment:
self.waiting_for_establishment[connection].append(event)
else:
stream = self.stream_by_command.pop(event)
self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None)))
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None)))
return
# Are we waiting for one?
for connection in self.waiting_for_connection:
if event.connection_spec_matches(connection):
self.waiting_for_connection[connection].append(event)
return
# Can we reuse context.server?
can_reuse_context_connection = (
self.context.server not in self.connections and
self.context.server.connected and
self.context.server.address == event.address and
self.context.server.tls == event.tls
event.connection_spec_matches(self.context.server)
)
if can_reuse_context_connection:
self.waiting_for_connection[self.context.server].append(event)
yield from self.make_http_connection(self.context.server)
# We need a new one.
context = self.context.fork()
layer = HttpClient(context)
if not can_reuse_context_connection:
context.server = Server(event.address)
if event.tls:
context.server.tls = True
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
orig = layer
layer = tls.ServerTLSLayer(context)
layer.child_layer = orig
# TODO: Here we should create other sublayer(s) for upstream proxy etc.
self.connections[context.server] = layer
self.waiting_for_establishment[context.server].append(event)
self.event_to_child(layer, events.Start())
def register_connection(self, command: RegisterHttpConnection):
waiting = self.waiting_for_establishment.pop(command.connection)
if command.err:
reply = (None, command.err)
self.connections.pop(command.connection)
else:
connection = Server(event.address)
connection.tls = event.tls
self.waiting_for_connection[connection].append(event)
open_command = commands.OpenConnection(connection)
open_command.blocking = object()
yield open_command
reply = (command.connection, None)
def make_http_connection(self, connection: Server) -> None:
if connection.tls and not connection.tls_established:
connection.alpn_offers = list(HTTP_ALPNS)
if not self.context.options.http2:
connection.alpn_offers.remove(b"h2")
new_command = EstablishServerTLS(connection)
new_command.blocking = object()
yield new_command
return
if connection.alpn == b"h2":
raise NotImplementedError
else:
self.connections[connection] = Http1Client(connection)
waiting = self.waiting_for_connection.pop(connection)
for cmd in waiting:
stream = self.stream_by_command.pop(cmd)
self.event_to_child(stream, GetHttpConnectionReply(cmd, (connection, None)))
self.event_to_child(stream, GetHttpConnectionReply(cmd, reply))
# Tricky multiplexing edge case: Assume a h2 client that sends two requests (or receives two responses)
# that neither have a content-length specified nor a chunked transfer encoding.
# We can't process these two flows to the same h1 connection as they would both have
# "read until eof" semantics. We could force chunked transfer encoding for requests, but can't enforce that
# for responses. The only workaround left is to open a separate connection for each flow.
if self.context.client.alpn == b"h2" and connection.alpn != b"h2":
if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
for cmd in waiting[1:]:
new_connection = Server(connection.address)
new_connection.tls = connection.tls
self.waiting_for_connection[new_connection].append(cmd)
open_command = commands.OpenConnection(new_connection)
open_command.blocking = object()
yield open_command
yield from self.get_connection(cmd, reuse=False)
break
def event_to_child(
self,
stream: typing.Union[HttpConnection, HttpStream],
child: typing.Union[layer.Layer, HttpStream],
event: events.Event,
) -> None:
stream_events = list(stream.handle_event(event))
for se in stream_events:
child_commands = list(child.handle_event(event))
for cmd in child_commands:
assert isinstance(cmd, commands.Command)
# Streams may yield blocking commands, which ultimately generate CommandReply events.
# Those need to be routed back to the correct stream, so we need to keep track of that.
if isinstance(se, commands.Command) and se.blocking:
self.stream_by_command[se] = stream
if isinstance(cmd, commands.OpenConnection):
self.connections[cmd.connection] = child
self.event_queue.extend(stream_events)
if cmd.blocking:
self.stream_by_command[cmd] = child
self.command_queue.extend(child_commands)
class HttpClient(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if self.context.server.connected:
err = None
else:
err = yield commands.OpenConnection(self.context.server)
yield RegisterHttpConnection(self.context.server, err)
if err:
return
if self.context.server.alpn == b"h2":
raise NotImplementedError
else:
child_layer = Http1Client(self.context)
self._handle_event = child_layer.handle_event

View File

@ -1,8 +1,6 @@
import abc
import typing
from dataclasses import dataclass
from mitmproxy.proxy2 import events, layer
from mitmproxy.proxy2 import events, layer, commands
StreamId = int
@ -18,14 +16,22 @@ class HttpEvent(events.Event):
return f"{type(self).__name__}({repr(x) if x else ''})"
class HttpConnection(abc.ABC):
@abc.abstractmethod
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]:
yield from ()
class HttpConnection(layer.Layer):
pass
@abc.abstractmethod
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
yield from ()
class HttpCommand(commands.Command):
pass
class ReceiveHttp(HttpCommand):
event: HttpEvent
def __init__(self, event: HttpEvent):
self.event = event
def __repr__(self) -> str:
return f"Receive({self.event})"
__all__ = [

View File

@ -1,5 +1,5 @@
import abc
import typing
from abc import abstractmethod
import h11
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
@ -9,8 +9,9 @@ from mitmproxy import http
from mitmproxy.net.http import http1
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Client, Connection, Server
from mitmproxy.proxy2.layers.http._base import StreamId
from mitmproxy.proxy2.context import Client, Connection, Context, Server
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
from mitmproxy.proxy2.utils import expect
from ._base import HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
@ -18,28 +19,32 @@ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
class Http1Connection(HttpConnection):
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
conn: Connection
stream_id: typing.Optional[StreamId] = None
request: typing.Optional[http.HTTPRequest] = None
response: typing.Optional[http.HTTPResponse] = None
request_done: bool = False
response_done: bool = False
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
state: typing.Callable[[events.Event], layer.CommandGenerator[None]]
body_reader: TBodyReader
buf: ReceiveBuffer
def __init__(self, conn: Connection):
def __init__(self, context: Context, conn: Connection):
super().__init__(context)
assert isinstance(conn, Connection)
self.conn = conn
self.buf = ReceiveBuffer()
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, HttpEvent):
yield from self.send(event)
else:
if isinstance(event, events.DataReceived):
self.buf += event.data
yield from self.state(event)
@abstractmethod
@abc.abstractmethod
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
yield from ()
@ -51,7 +56,7 @@ class Http1Connection(HttpConnection):
else:
return ContentLengthReader(expected_size)
def read_body(self, event: events.Event, is_request: bool) -> typing.Iterator[HttpEvent]:
def read_body(self, event: events.Event, is_request: bool) -> layer.CommandGenerator[None]:
while True:
try:
if isinstance(event, events.DataReceived):
@ -63,9 +68,9 @@ class Http1Connection(HttpConnection):
except h11.ProtocolError as e:
yield commands.CloseConnection(self.conn)
if is_request:
yield RequestProtocolError(self.stream_id, str(e))
yield ReceiveHttp(RequestProtocolError(self.stream_id, str(e)))
else:
yield ResponseProtocolError(self.stream_id, str(e))
yield ReceiveHttp(ResponseProtocolError(self.stream_id, str(e)))
return
if h11_event is None:
@ -73,17 +78,17 @@ class Http1Connection(HttpConnection):
elif isinstance(h11_event, h11.Data):
h11_event.data: bytearray # type checking
if is_request:
yield RequestData(self.stream_id, bytes(h11_event.data))
yield ReceiveHttp(RequestData(self.stream_id, bytes(h11_event.data)))
else:
yield ResponseData(self.stream_id, bytes(h11_event.data))
yield ReceiveHttp(ResponseData(self.stream_id, bytes(h11_event.data)))
elif isinstance(h11_event, h11.EndOfMessage):
if is_request:
yield RequestEndOfMessage(self.stream_id)
yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
else:
yield ResponseEndOfMessage(self.stream_id)
yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
return
def wait(self, event: events.Event) -> typing.Iterator[HttpEvent]:
def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
"""
We wait for the current flow to be finished before parsing the next message,
as we may want to upgrade to WebSocket or plain TCP before that.
@ -101,8 +106,8 @@ class Http1Server(Http1Connection):
"""A simple HTTP/1 server with no pipelining support."""
conn: Client
def __init__(self, conn: Client):
super().__init__(conn)
def __init__(self, context: Context):
super().__init__(context, context.client)
self.stream_id = 1
self.state = self.read_request_headers
@ -138,7 +143,7 @@ class Http1Server(Http1Connection):
else:
raise NotImplementedError(f"{event}")
def mark_done(self, *, request: bool = False, response: bool = False):
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
if request:
self.request_done = True
if response:
@ -152,13 +157,13 @@ class Http1Server(Http1Connection):
elif self.request_done:
self.state = self.wait
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
def read_request_headers(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
request_head = self.buf.maybe_extract_lines()
if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
yield RequestHeaders(self.stream_id, self.request)
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
if self.request.first_line_format == "authority":
# The previous proxy server implementation tried to read the request body here:
@ -180,10 +185,10 @@ class Http1Server(Http1Connection):
else:
raise ValueError(f"Unexpected event: {event}")
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
def read_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
for e in self.read_body(event, True):
yield e
if isinstance(e, RequestEndOfMessage):
if isinstance(e, ReceiveHttp) and isinstance(e.event, RequestEndOfMessage):
yield from self.mark_done(request=True)
@ -192,8 +197,8 @@ class Http1Client(Http1Connection):
send_queue: typing.List[HttpEvent]
"""A queue of send events for flows other than the one that is currently being transmitted."""
def __init__(self, conn: Server):
super().__init__(conn)
def __init__(self, context: Context):
super().__init__(context, context.server)
self.state = self.read_response_headers
self.send_queue = []
@ -231,7 +236,7 @@ class Http1Client(Http1Connection):
else:
raise NotImplementedError(f"{event}")
def mark_done(self, *, request: bool = False, response: bool = False):
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
if request:
self.request_done = True
if response:
@ -246,15 +251,15 @@ class Http1Client(Http1Connection):
for ev in send_queue:
yield from self.send(ev)
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
assert isinstance(event, events.ConnectionEvent)
@expect(events.ConnectionEvent)
def read_response_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
response_head = self.buf.maybe_extract_lines()
if response_head:
response_head = [bytes(x) for x in response_head]
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
yield ResponseHeaders(self.stream_id, self.response)
yield ReceiveHttp(ResponseHeaders(self.stream_id, self.response))
expected_size = http1.expected_http_body_size(self.request, self.response)
self.body_reader = self.make_body_reader(expected_size)
@ -266,23 +271,24 @@ class Http1Client(Http1Connection):
elif isinstance(event, events.ConnectionClosed):
if self.stream_id:
if self.buf:
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}")
yield ReceiveHttp(
ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}"))
else:
# The server has closed the connection to prevent us from continuing.
# We need to signal that to the stream.
# https://tools.ietf.org/html/rfc7231#section-6.5.11
yield ResponseProtocolError(self.stream_id, "server closed connection")
yield ReceiveHttp(ResponseProtocolError(self.stream_id, "server closed connection"))
else:
return
yield commands.CloseConnection(self.conn)
else:
raise ValueError(f"Unexpected event: {event}")
def read_response_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
assert isinstance(event, events.ConnectionEvent)
@expect(events.ConnectionEvent)
def read_response_body(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
for e in self.read_body(event, False):
yield e
if isinstance(e, ResponseEndOfMessage):
if isinstance(e, ReceiveHttp) and isinstance(e.event, ResponseEndOfMessage):
self.state = self.read_response_headers
yield from self.mark_done(response=True)

View File

@ -67,7 +67,7 @@ class TCPLayer(layer.Layer):
if self.flow:
tcp_message = tcp.TCPMessage(from_client, event.data)
self.flow.messages.append(tcp_message)
yield TcpMessageHook(self.flow)t
yield TcpMessageHook(self.flow)
yield commands.SendData(send_to, tcp_message.content)
else:
yield commands.SendData(send_to, event.data)

View File

@ -1,7 +1,7 @@
import struct
import time
from dataclasses import dataclass
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
from typing import Iterator, Optional, Tuple
from OpenSSL import SSL
@ -93,20 +93,6 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
class EstablishServerTLS(commands.ConnectionCommand):
"""Establish TLS on the given connection.
If TLS establishment fails, the connection will automatically be closed by the TLS layer."""
connection: context.Server
blocking = True
class EstablishServerTLSReply(events.CommandReply):
command: EstablishServerTLS
reply: Optional[str]
"""error message"""
# We need these classes as hooks can only have one argument at the moment.
@dataclass
@ -131,62 +117,61 @@ class TlsClienthelloHook(Hook):
class _TLSLayer(layer.Layer):
tls: Dict[context.Connection, SSL.Connection]
conn: context.Connection
"""The connection for which we do TLS"""
tls: SSL.Connection = None
"""The OpenSSL connection object"""
child_layer: layer.Layer
ssl_context: Optional[SSL.Context] = None
def __init__(self, context: context.Context):
def __init__(self, context: context.Context, conn: context.Connection):
super().__init__(context)
self.tls = {}
self.conn = conn
def __repr__(self):
if not self.tls:
state = "inactive"
elif self.conn.tls_established:
state = f"passthrough {self.conn.sni} {self.conn.alpn}"
else:
conn_states = []
for conn in self.tls:
if conn.tls_established:
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
else:
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
state = ", ".join(conn_states)
state = f"negotiating {self.conn.sni} {self.conn.alpn}"
return f"{type(self).__name__}({state})"
def start_tls(self, conn: context.Connection, initial_data: bytes = b""):
assert conn not in self.tls
assert conn.connected
conn.tls = True
def start_tls(self, initial_data: bytes = b""):
assert not self.tls
assert self.conn.connected
self.conn.tls = True
tls_start = TlsStartData(conn, self.context)
tls_start = TlsStartData(self.conn, self.context)
yield TlsStartHook(tls_start)
self.tls[conn] = tls_start.ssl_conn
assert tls_start.ssl_conn
self.tls = tls_start.ssl_conn
yield from self.negotiate(conn, initial_data)
yield from self.negotiate(initial_data)
def tls_interact(self, conn: context.Connection) -> layer.CommandGenerator[None]:
def tls_interact(self) -> layer.CommandGenerator[None]:
while True:
try:
data = self.tls[conn].bio_read(65535)
data = self.tls.bio_read(65535)
except SSL.WantReadError:
# Okay, nothing more waiting to be sent.
return
else:
yield commands.SendData(conn, data)
yield commands.SendData(self.conn, data)
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
# bio_write errors for b"", so we need to check first if we actually received something.
if data:
self.tls[conn].bio_write(data)
self.tls.bio_write(data)
try:
self.tls[conn].do_handshake()
self.tls.do_handshake()
except SSL.WantReadError:
yield from self.tls_interact(conn)
yield from self.tls_interact()
return False, None
except SSL.Error as e:
# provide more detailed information for some errors.
last_err = e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1]
if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'):
verify_result = SSL._lib.SSL_get_verify_result(self.tls[conn]._ssl)
verify_result = SSL._lib.SSL_get_verify_result(self.tls._ssl)
error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode()
err = f"Certificate verify failed: {error}"
elif last_err in [
@ -196,40 +181,40 @@ class _TLSLayer(layer.Layer):
err = last_err[2]
else:
err = repr(e)
yield from self.on_handshake_error(conn, err)
yield from self.on_handshake_error(err)
return False, err
else:
# Get all peer certificates.
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
# If called on the client side, the stack also contains the peer's certificate; if called on the server
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
all_certs = self.tls[conn].get_peer_cert_chain() or []
if conn == self.context.client:
cert = self.tls[conn].get_peer_certificate()
all_certs = self.tls.get_peer_cert_chain() or []
if self.conn == self.context.client:
cert = self.tls.get_peer_certificate()
if cert:
all_certs.insert(0, cert)
conn.tls_established = True
conn.sni = self.tls[conn].get_servername()
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
conn.certificate_chain = [certs.Cert(x) for x in all_certs]
conn.cipher_list = self.tls[conn].get_cipher_list()
conn.tls_version = self.tls[conn].get_protocol_version_name()
conn.timestamp_tls_setup = time.time()
yield commands.Log(f"TLS established: {conn}")
yield from self.receive(conn, b"")
self.conn.tls_established = True
self.conn.sni = self.tls.get_servername()
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
self.conn.certificate_chain = [certs.Cert(x) for x in all_certs]
self.conn.cipher_list = self.tls.get_cipher_list()
self.conn.tls_version = self.tls.get_protocol_version_name()
self.conn.timestamp_tls_setup = time.time()
yield commands.Log(f"TLS established: {self.conn}")
yield from self.receive(b"")
return True, None
def receive(self, conn: context.Connection, data: bytes):
def receive(self, data: bytes):
if data:
self.tls[conn].bio_write(data)
yield from self.tls_interact(conn)
self.tls.bio_write(data)
yield from self.tls_interact()
plaintext = bytearray()
close = False
while True:
try:
plaintext.extend(self.tls[conn].recv(65535))
plaintext.extend(self.tls.recv(65535))
except SSL.WantReadError:
break
except SSL.ZeroReturnError:
@ -238,76 +223,90 @@ class _TLSLayer(layer.Layer):
if plaintext:
yield from self.event_to_child(
events.DataReceived(conn, bytes(plaintext))
events.DataReceived(self.conn, bytes(plaintext))
)
if close:
conn.state &= ~context.ConnectionState.CAN_READ
yield commands.Log(f"TLS close_notify {conn=}")
self.conn.state &= ~context.ConnectionState.CAN_READ
yield commands.Log(f"TLS close_notify {self.conn}")
yield from self.event_to_child(
events.ConnectionClosed(conn)
events.ConnectionClosed(self.conn)
)
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
for command in self.child_layer.handle_event(event):
if isinstance(command, commands.SendData) and command.connection in self.tls:
self.tls[command.connection].sendall(command.data)
yield from self.tls_interact(command.connection)
if isinstance(command, commands.SendData) and command.connection == self.conn:
self.tls.sendall(command.data)
yield from self.tls_interact()
else:
yield command
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.connection in self.tls:
if isinstance(event, events.DataReceived) and event.connection == self.conn:
if not event.connection.tls_established:
yield from self.negotiate(event.connection, event.data)
yield from self.negotiate(event.data)
else:
yield from self.receive(event.connection, event.data)
elif isinstance(event, events.ConnectionClosed) and event.connection in self.tls:
yield from self.receive(event.data)
elif isinstance(event, events.ConnectionClosed) and event.connection == self.conn:
if event.connection.tls_established:
if self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN:
if self.tls.get_shutdown() & SSL.RECEIVED_SHUTDOWN:
pass # We have already dispatched a ConnectionClosed to the child layer.
else:
yield from self.event_to_child(event)
else:
yield from self.on_handshake_error(event.connection, "connection closed without notice")
yield from self.on_handshake_error("connection closed without notice")
else:
yield from self.event_to_child(event)
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
yield commands.CloseConnection(conn)
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
yield commands.CloseConnection(self.conn)
class ServerTLSLayer(_TLSLayer):
"""
This layer manages TLS for potentially multiple server connections.
This layer establishes TLS for a single server connection.
"""
command_to_reply_to: Dict[context.Connection, commands.OpenConnection]
command_to_reply_to: Optional[commands.OpenConnection] = None
def __init__(self, context: context.Context):
super().__init__(context)
self.command_to_reply_to = {}
super().__init__(context, context.server)
self.child_layer = layer.NextLayer(self.context)
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
done, err = yield from super().negotiate(conn, data)
@expect(events.Start)
def state_start(self, _) -> layer.CommandGenerator[None]:
self.context.server.tls = True
if self.context.server.connected:
yield from self.start_tls()
self._handle_event = super()._handle_event
_handle_event = state_start
def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
done, err = yield from super().negotiate(data)
if done or err:
cmd = self.command_to_reply_to.pop(conn)
cmd = self.command_to_reply_to
yield from self.event_to_child(events.OpenConnectionReply(cmd, err))
self.command_to_reply_to = None
return done, err
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
for command in super().event_to_child(event):
if isinstance(command, EstablishServerTLS):
self.command_to_reply_to[command.connection] = command
yield from self.start_tls(command.connection)
if isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
# create our own OpenConnection command object that blocks here.
err = yield commands.OpenConnection(command.connection)
if err:
yield from self.event_to_child(events.OpenConnectionReply(command, err))
else:
self.command_to_reply_to = command
yield from self.start_tls()
else:
yield command
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
yield commands.Log(
f"Server TLS handshake failed. {err}",
level="warn"
)
yield from super().on_handshake_error(conn, err)
yield from super().on_handshake_error(err)
class ClientTLSLayer(_TLSLayer):
@ -331,10 +330,9 @@ class ClientTLSLayer(_TLSLayer):
def __init__(self, context: context.Context):
assert isinstance(context.layers[-1], ServerTLSLayer)
super().__init__(context)
super().__init__(context, self.context.client)
self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context)
self._handle_event = self.state_start
@expect(events.Start)
def state_start(self, _) -> layer.CommandGenerator[None]:
@ -342,20 +340,21 @@ class ClientTLSLayer(_TLSLayer):
self._handle_event = self.state_wait_for_clienthello
yield from ()
_handle_event = state_start
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
client = self.context.client
if isinstance(event, events.DataReceived) and event.connection == client:
if isinstance(event, events.DataReceived) and event.connection == self.conn:
self.recv_buffer.extend(event.data)
try:
client_hello = parse_client_hello(self.recv_buffer)
except ValueError:
yield commands.Log(f"Cannot parse ClientHello: {self.recv_buffer.hex()}")
yield commands.CloseConnection(client)
yield commands.CloseConnection(self.conn)
return
if client_hello:
client.sni = client_hello.sni
client.alpn_offers = client_hello.alpn_protocols
self.conn.sni = client_hello.sni
self.conn.alpn_offers = client_hello.alpn_protocols
tls_clienthello = ClientHelloData(self.context)
yield TlsClienthelloHook(tls_clienthello)
@ -365,7 +364,7 @@ class ClientTLSLayer(_TLSLayer):
yield commands.Log("Unable to establish TLS connection with server. "
"Trying to establish TLS with client anyway.")
yield from self.start_tls(client, bytes(self.recv_buffer))
yield from self.start_tls(bytes(self.recv_buffer))
self.recv_buffer.clear()
self._handle_event = super()._handle_event
@ -379,25 +378,16 @@ class ClientTLSLayer(_TLSLayer):
We often need information from the upstream connection to establish TLS with the client.
For example, we need to check if the client does ALPN or not.
"""
server = self.context.server
if not server.connected:
err = yield commands.OpenConnection(server)
err = yield commands.OpenConnection(self.context.server)
if err:
yield commands.Log(
f"Cannot establish server connection: {err}"
)
yield commands.Log(f"Cannot establish server connection: {err}")
return err
else:
return None
err = yield EstablishServerTLS(server)
if err:
yield commands.Log(
f"Cannot establish TLS with server: {err}"
)
return err
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
if conn.sni:
dest = conn.sni.decode("idna")
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
if self.conn.sni:
dest = self.conn.sni.decode("idna")
else:
dest = human.format_address(self.context.server.address)
if "unknown ca" in err or "bad certificate" in err:
@ -409,4 +399,4 @@ class ClientTLSLayer(_TLSLayer):
f"The client {keyword} trust the proxy's certificate for {dest} ({err})",
level="warn"
)
yield from super().on_handshake_error(conn, err)
yield from super().on_handshake_error(err)

View File

@ -18,7 +18,7 @@ def expect(*event_types):
if isinstance(event, event_types):
yield from f(self, event)
else:
event_types_str = '|'.join(e.__name__ for e in event_types)
event_types_str = '|'.join(e.__name__ for e in event_types) or "no events"
raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.")
return wrapper

View File

@ -106,8 +106,8 @@ def test_redirect(strategy, https_server, https_client, tctx):
p << OpenConnection(server)
p >> reply(None)
if https_server:
p << tls.EstablishServerTLS(server)
p >> reply_establish_server_tls()
pass # p << tls.EstablishServerTLS(server)
# p >> reply_establish_server_tls()
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
@ -123,6 +123,7 @@ def test_multiple_server_connections(tctx):
"""Test multiple requests being rewritten to different targets."""
server1 = Placeholder()
server2 = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
def redirect(to: str):
def side_effect(flow: HTTPFlow):
@ -131,7 +132,7 @@ def test_multiple_server_connections(tctx):
return side_effect
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHook(Placeholder())
>> reply(side_effect=redirect("http://one.redirect/"))
@ -140,6 +141,9 @@ def test_multiple_server_connections(tctx):
<< SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n")
>> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
)
assert (
playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHook(Placeholder())
>> reply(side_effect=redirect("http://two.redirect/"))

View File

@ -4,8 +4,6 @@ import typing
import pytest
from OpenSSL import SSL
import mitmproxy.proxy2.layer
import mitmproxy.proxy2.layers.tls
from mitmproxy.proxy2 import commands, context, events, layer
from mitmproxy.proxy2.layers import tls
from mitmproxy.utils import data
@ -110,11 +108,11 @@ class TlsEchoLayer(tutils.EchoLayer):
err: typing.Optional[str] = None
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
# noinspection PyTypeChecker
err = yield tls.EstablishServerTLS(self.context.server)
err = yield commands.OpenConnection(self.context.server)
if err:
yield commands.SendData(event.connection, f"server-tls-failed: {err}".encode())
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
else:
yield from super()._handle_event(event)
@ -208,9 +206,6 @@ class TestServerTLS:
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.client, b"establish-server-tls")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer)
<< tls.TlsStartHook(tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.server, data)
@ -233,6 +228,13 @@ class TestServerTLS:
assert tctx.server.tls_established
# Echo
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer)
<< commands.SendData(tctx.client, b"foo")
)
_test_echo(playbook, tssl, tctx.server)
with pytest.raises(ssl.SSLWantReadError):

View File

@ -170,7 +170,11 @@ class Playbook:
pass
else:
if hasattr(x, "playbook_eval"):
try:
x = self.expected[i] = x.playbook_eval(self)
except Exception:
self.actual.append(_TracebackInPlaybook(traceback.format_exc()))
break
for name, value in vars(x).items():
if isinstance(value, _Placeholder):
setattr(x, name, value())
@ -273,8 +277,7 @@ class reply(events.Event):
self.to = cmd
break
else:
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
raise AssertionError(f"Expected command {self.to} did not occur:\n{actual_str}")
raise AssertionError(f"Expected command {self.to} did not occur.")
assert isinstance(self.to, commands.Command)
if isinstance(self.to, commands.Hook):