[sans-io] wip: tls establishment semantics

This commit is contained in:
Maximilian Hils 2019-12-27 18:39:01 +01:00
parent 7efe27be74
commit b2060356b6
10 changed files with 294 additions and 286 deletions

View File

@ -74,16 +74,6 @@ class CloseConnection(ConnectionCommand):
""" """
class ProtocolError(Command):
"""
Indicate that an unrecoverable protocol error has occured.
"""
message: str
def __init__(self, message: str):
self.message = message
class Hook(Command): class Hook(Command):
""" """
Callback to the master (like ".ask()") Callback to the master (like ".ask()")
@ -143,5 +133,3 @@ class Log(Command):
def __repr__(self): def __repr__(self):
return f"Log({self.message}, {self.level})" return f"Log({self.message}, {self.level})"

View File

@ -5,10 +5,10 @@ from mitmproxy import flow, http
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Connection, Context, Server from mitmproxy.proxy2.context import Connection, Context, Server
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
from ._base import HttpConnection, StreamId from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \ from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
@ -17,10 +17,6 @@ from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client from ._http2 import Http2Client
class HttpCommand(commands.Command):
pass
class GetHttpConnection(HttpCommand): class GetHttpConnection(HttpCommand):
""" """
Open a HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead. Open a HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead.
@ -47,6 +43,18 @@ class GetHttpConnectionReply(events.CommandReply):
"""connection object, error message""" """connection object, error message"""
class RegisterHttpConnection(HttpCommand):
"""
Register that a HTTP connection has been successfully established.
"""
connection: Connection
err: str
def __init__(self, connection: Connection, err: str):
self.connection = connection
self.err = err
class SendHttp(HttpCommand): class SendHttp(HttpCommand):
connection: Connection connection: Connection
event: HttpEvent event: HttpEvent
@ -59,6 +67,8 @@ class SendHttp(HttpCommand):
return f"Send({self.event})" return f"Send({self.event})"
class HttpStream(layer.Layer): class HttpStream(layer.Layer):
request_body_buf: bytes request_body_buf: bytes
response_body_buf: bytes response_body_buf: bytes
@ -299,7 +309,7 @@ class HttpStream(layer.Layer):
yield from () yield from ()
class HTTPLayer(Layer): class HTTPLayer(layer.Layer):
""" """
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client. ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
HttpEvent: We have received request headers HttpEvent: We have received request headers
@ -311,47 +321,34 @@ class HTTPLayer(Layer):
mode: HTTPMode mode: HTTPMode
stream_by_command: typing.Dict[commands.Command, HttpStream] stream_by_command: typing.Dict[commands.Command, HttpStream]
streams: typing.Dict[int, HttpStream] streams: typing.Dict[int, HttpStream]
connections: typing.Dict[Connection, typing.Union[HttpConnection, HttpStream]] connections: typing.Dict[Connection, typing.Union[layer.Layer, HttpStream]]
waiting_for_connection: typing.DefaultDict[Connection, typing.List[GetHttpConnection]] waiting_for_establishment: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
event_queue: typing.Deque[ command_queue: typing.Deque[commands.Command]
typing.Union[HttpEvent, HttpCommand, commands.Command]
]
def __init__(self, context: Context, mode: HTTPMode): def __init__(self, context: Context, mode: HTTPMode):
super().__init__(context) super().__init__(context)
self.mode = mode self.mode = mode
self.waiting_for_connection = collections.defaultdict(list) self.waiting_for_establishment = collections.defaultdict(list)
self.streams = {} self.streams = {}
self.stream_by_command = {} self.stream_by_command = {}
self.event_queue = collections.deque() self.command_queue = collections.deque()
self.connections = { self.connections = {
context.client: Http1Server(context.client) context.client: Http1Server(context.fork())
} }
def __repr__(self): def __repr__(self):
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})" return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
def _handle_event(self, event: events.Event): def _handle_event(self, event: events.Event):
if isinstance(event, events.Start): if isinstance(event, events.Start):
return return
elif isinstance(event, (EstablishServerTLSReply, events.OpenConnectionReply)) and \
event.command.connection in self.waiting_for_connection:
if event.reply:
waiting = self.waiting_for_connection.pop(event.command.connection)
for cmd in waiting:
stream = self.stream_by_command.pop(cmd)
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
else:
yield from self.make_http_connection(event.command.connection)
elif isinstance(event, events.CommandReply): elif isinstance(event, events.CommandReply):
try: try:
stream = self.stream_by_command.pop(event.command) stream = self.stream_by_command.pop(event.command)
except KeyError: except KeyError:
raise raise
if isinstance(event, events.OpenConnectionReply):
self.connections[event.command.connection] = stream
self.event_to_child(stream, event) self.event_to_child(stream, event)
elif isinstance(event, events.ConnectionEvent): elif isinstance(event, events.ConnectionEvent):
if event.connection == self.context.server and self.context.server not in self.connections: if event.connection == self.context.server and self.context.server not in self.connections:
@ -362,117 +359,129 @@ class HTTPLayer(Layer):
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
while self.event_queue: while self.command_queue:
event = self.event_queue.popleft() command = self.command_queue.popleft()
if isinstance(event, RequestHeaders): if isinstance(command, ReceiveHttp):
self.streams[event.stream_id] = self.make_stream() if isinstance(command.event, RequestHeaders):
if isinstance(event, HttpEvent): self.streams[command.event.stream_id] = self.make_stream()
stream = self.streams[event.stream_id] stream = self.streams[command.event.stream_id]
self.event_to_child(stream, event) self.event_to_child(stream, command.event)
elif isinstance(event, SendHttp): elif isinstance(command, SendHttp):
conn = self.connections[event.connection] conn = self.connections[command.connection]
evts = conn.send(event.event) self.event_to_child(conn, command.event)
self.event_queue.extend(evts) elif isinstance(command, GetHttpConnection):
elif isinstance(event, GetHttpConnection): self.get_connection(command)
yield from self.get_connection(event) elif isinstance(command, RegisterHttpConnection):
elif isinstance(event, commands.Command): yield from self.register_connection(command)
yield event elif isinstance(command, commands.Command):
else: yield command
raise ValueError(f"Unexpected event: {event}") else: # pragma: no cover
raise ValueError(f"Not a command command: {command}")
def make_stream(self) -> HttpStream: def make_stream(self) -> HttpStream:
ctx = self.context.fork() ctx = self.context.fork()
stream = HttpStream(ctx) stream = HttpStream(ctx)
if self.debug: if self.debug:
stream.debug = self.debug + " " stream.debug = self.debug + " "
self.event_to_child(stream, events.Start()) self.event_to_child(stream, events.Start())
return stream return stream
def get_connection(self, event: GetHttpConnection): def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
# Do we already have a connection we can re-use? # Do we already have a connection we can re-use?
for connection, handler in self.connections.items(): for connection, layer in self.connections.items():
connection_suitable = ( connection_suitable = (
reuse and
event.connection_spec_matches(connection) and event.connection_spec_matches(connection) and
( (
isinstance(handler, Http2Client) or # see "tricky multiplexing edge case" in make_http_connection for an explanation
# see "tricky multiplexing edge case" in make_http_connection for an explanation connection.alpn == b"h2" or self.context.client.alpn != b"h2"
isinstance(handler, Http1Client) and self.context.client.alpn != b"h2"
) )
) )
if connection_suitable: if connection_suitable:
stream = self.stream_by_command.pop(event) if connection in self.waiting_for_establishment:
self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None))) self.waiting_for_establishment[connection].append(event)
else:
stream = self.stream_by_command.pop(event)
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None)))
return return
# Are we waiting for one?
for connection in self.waiting_for_connection:
if event.connection_spec_matches(connection):
self.waiting_for_connection[connection].append(event)
return
# Can we reuse context.server?
can_reuse_context_connection = ( can_reuse_context_connection = (
self.context.server not in self.connections and self.context.server not in self.connections and
self.context.server.connected and self.context.server.connected and
self.context.server.address == event.address and event.connection_spec_matches(self.context.server)
self.context.server.tls == event.tls
) )
if can_reuse_context_connection: context = self.context.fork()
self.waiting_for_connection[self.context.server].append(event) layer = HttpClient(context)
yield from self.make_http_connection(self.context.server) if not can_reuse_context_connection:
# We need a new one. context.server = Server(event.address)
if event.tls:
context.server.tls = True
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
orig = layer
layer = tls.ServerTLSLayer(context)
layer.child_layer = orig
# TODO: Here we should create other sublayer(s) for upstream proxy etc.
self.connections[context.server] = layer
self.waiting_for_establishment[context.server].append(event)
self.event_to_child(layer, events.Start())
def register_connection(self, command: RegisterHttpConnection):
waiting = self.waiting_for_establishment.pop(command.connection)
if command.err:
reply = (None, command.err)
self.connections.pop(command.connection)
else: else:
connection = Server(event.address) reply = (command.connection, None)
connection.tls = event.tls
self.waiting_for_connection[connection].append(event)
open_command = commands.OpenConnection(connection)
open_command.blocking = object()
yield open_command
def make_http_connection(self, connection: Server) -> None:
if connection.tls and not connection.tls_established:
connection.alpn_offers = list(HTTP_ALPNS)
if not self.context.options.http2:
connection.alpn_offers.remove(b"h2")
new_command = EstablishServerTLS(connection)
new_command.blocking = object()
yield new_command
return
if connection.alpn == b"h2":
raise NotImplementedError
else:
self.connections[connection] = Http1Client(connection)
waiting = self.waiting_for_connection.pop(connection)
for cmd in waiting: for cmd in waiting:
stream = self.stream_by_command.pop(cmd) stream = self.stream_by_command.pop(cmd)
self.event_to_child(stream, GetHttpConnectionReply(cmd, (connection, None))) self.event_to_child(stream, GetHttpConnectionReply(cmd, reply))
# Tricky multiplexing edge case: Assume a h2 client that sends two requests (or receives two responses) # Tricky multiplexing edge case: Assume a h2 client that sends two requests (or receives two responses)
# that neither have a content-length specified nor a chunked transfer encoding. # that neither have a content-length specified nor a chunked transfer encoding.
# We can't process these two flows to the same h1 connection as they would both have # We can't process these two flows to the same h1 connection as they would both have
# "read until eof" semantics. We could force chunked transfer encoding for requests, but can't enforce that # "read until eof" semantics. We could force chunked transfer encoding for requests, but can't enforce that
# for responses. The only workaround left is to open a separate connection for each flow. # for responses. The only workaround left is to open a separate connection for each flow.
if self.context.client.alpn == b"h2" and connection.alpn != b"h2": if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
for cmd in waiting[1:]: for cmd in waiting[1:]:
new_connection = Server(connection.address) yield from self.get_connection(cmd, reuse=False)
new_connection.tls = connection.tls
self.waiting_for_connection[new_connection].append(cmd)
open_command = commands.OpenConnection(new_connection)
open_command.blocking = object()
yield open_command
break break
def event_to_child( def event_to_child(
self, self,
stream: typing.Union[HttpConnection, HttpStream], child: typing.Union[layer.Layer, HttpStream],
event: events.Event, event: events.Event,
) -> None: ) -> None:
stream_events = list(stream.handle_event(event)) child_commands = list(child.handle_event(event))
for se in stream_events: for cmd in child_commands:
assert isinstance(cmd, commands.Command)
# Streams may yield blocking commands, which ultimately generate CommandReply events. # Streams may yield blocking commands, which ultimately generate CommandReply events.
# Those need to be routed back to the correct stream, so we need to keep track of that. # Those need to be routed back to the correct stream, so we need to keep track of that.
if isinstance(se, commands.Command) and se.blocking: if isinstance(cmd, commands.OpenConnection):
self.stream_by_command[se] = stream self.connections[cmd.connection] = child
self.event_queue.extend(stream_events) if cmd.blocking:
self.stream_by_command[cmd] = child
self.command_queue.extend(child_commands)
class HttpClient(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if self.context.server.connected:
err = None
else:
err = yield commands.OpenConnection(self.context.server)
yield RegisterHttpConnection(self.context.server, err)
if err:
return
if self.context.server.alpn == b"h2":
raise NotImplementedError
else:
child_layer = Http1Client(self.context)
self._handle_event = child_layer.handle_event

View File

@ -1,8 +1,6 @@
import abc
import typing
from dataclasses import dataclass from dataclasses import dataclass
from mitmproxy.proxy2 import events, layer from mitmproxy.proxy2 import events, layer, commands
StreamId = int StreamId = int
@ -18,14 +16,22 @@ class HttpEvent(events.Event):
return f"{type(self).__name__}({repr(x) if x else ''})" return f"{type(self).__name__}({repr(x) if x else ''})"
class HttpConnection(abc.ABC): class HttpConnection(layer.Layer):
@abc.abstractmethod pass
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]:
yield from ()
@abc.abstractmethod
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]: class HttpCommand(commands.Command):
yield from () pass
class ReceiveHttp(HttpCommand):
event: HttpEvent
def __init__(self, event: HttpEvent):
self.event = event
def __repr__(self) -> str:
return f"Receive({self.event})"
__all__ = [ __all__ = [

View File

@ -1,5 +1,5 @@
import abc
import typing import typing
from abc import abstractmethod
import h11 import h11
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
@ -9,8 +9,9 @@ from mitmproxy import http
from mitmproxy.net.http import http1 from mitmproxy.net.http import http1
from mitmproxy.net.http.http1 import read_sansio as http1_sansio from mitmproxy.net.http.http1 import read_sansio as http1_sansio
from mitmproxy.proxy2 import commands, events, layer from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Client, Connection, Server from mitmproxy.proxy2.context import Client, Connection, Context, Server
from mitmproxy.proxy2.layers.http._base import StreamId from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
from mitmproxy.proxy2.utils import expect
from ._base import HttpConnection from ._base import HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
@ -18,28 +19,32 @@ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader] TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
class Http1Connection(HttpConnection): class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
conn: Connection conn: Connection
stream_id: typing.Optional[StreamId] = None stream_id: typing.Optional[StreamId] = None
request: typing.Optional[http.HTTPRequest] = None request: typing.Optional[http.HTTPRequest] = None
response: typing.Optional[http.HTTPResponse] = None response: typing.Optional[http.HTTPResponse] = None
request_done: bool = False request_done: bool = False
response_done: bool = False response_done: bool = False
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]] state: typing.Callable[[events.Event], layer.CommandGenerator[None]]
body_reader: TBodyReader body_reader: TBodyReader
buf: ReceiveBuffer buf: ReceiveBuffer
def __init__(self, conn: Connection): def __init__(self, context: Context, conn: Connection):
super().__init__(context)
assert isinstance(conn, Connection) assert isinstance(conn, Connection)
self.conn = conn self.conn = conn
self.buf = ReceiveBuffer() self.buf = ReceiveBuffer()
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]: def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived): if isinstance(event, HttpEvent):
self.buf += event.data yield from self.send(event)
yield from self.state(event) else:
if isinstance(event, events.DataReceived):
self.buf += event.data
yield from self.state(event)
@abstractmethod @abc.abstractmethod
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]: def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
yield from () yield from ()
@ -51,7 +56,7 @@ class Http1Connection(HttpConnection):
else: else:
return ContentLengthReader(expected_size) return ContentLengthReader(expected_size)
def read_body(self, event: events.Event, is_request: bool) -> typing.Iterator[HttpEvent]: def read_body(self, event: events.Event, is_request: bool) -> layer.CommandGenerator[None]:
while True: while True:
try: try:
if isinstance(event, events.DataReceived): if isinstance(event, events.DataReceived):
@ -63,9 +68,9 @@ class Http1Connection(HttpConnection):
except h11.ProtocolError as e: except h11.ProtocolError as e:
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
if is_request: if is_request:
yield RequestProtocolError(self.stream_id, str(e)) yield ReceiveHttp(RequestProtocolError(self.stream_id, str(e)))
else: else:
yield ResponseProtocolError(self.stream_id, str(e)) yield ReceiveHttp(ResponseProtocolError(self.stream_id, str(e)))
return return
if h11_event is None: if h11_event is None:
@ -73,17 +78,17 @@ class Http1Connection(HttpConnection):
elif isinstance(h11_event, h11.Data): elif isinstance(h11_event, h11.Data):
h11_event.data: bytearray # type checking h11_event.data: bytearray # type checking
if is_request: if is_request:
yield RequestData(self.stream_id, bytes(h11_event.data)) yield ReceiveHttp(RequestData(self.stream_id, bytes(h11_event.data)))
else: else:
yield ResponseData(self.stream_id, bytes(h11_event.data)) yield ReceiveHttp(ResponseData(self.stream_id, bytes(h11_event.data)))
elif isinstance(h11_event, h11.EndOfMessage): elif isinstance(h11_event, h11.EndOfMessage):
if is_request: if is_request:
yield RequestEndOfMessage(self.stream_id) yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
else: else:
yield ResponseEndOfMessage(self.stream_id) yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
return return
def wait(self, event: events.Event) -> typing.Iterator[HttpEvent]: def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
""" """
We wait for the current flow to be finished before parsing the next message, We wait for the current flow to be finished before parsing the next message,
as we may want to upgrade to WebSocket or plain TCP before that. as we may want to upgrade to WebSocket or plain TCP before that.
@ -101,8 +106,8 @@ class Http1Server(Http1Connection):
"""A simple HTTP/1 server with no pipelining support.""" """A simple HTTP/1 server with no pipelining support."""
conn: Client conn: Client
def __init__(self, conn: Client): def __init__(self, context: Context):
super().__init__(conn) super().__init__(context, context.client)
self.stream_id = 1 self.stream_id = 1
self.state = self.read_request_headers self.state = self.read_request_headers
@ -138,7 +143,7 @@ class Http1Server(Http1Connection):
else: else:
raise NotImplementedError(f"{event}") raise NotImplementedError(f"{event}")
def mark_done(self, *, request: bool = False, response: bool = False): def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
if request: if request:
self.request_done = True self.request_done = True
if response: if response:
@ -152,13 +157,13 @@ class Http1Server(Http1Connection):
elif self.request_done: elif self.request_done:
self.state = self.wait self.state = self.wait
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]: def read_request_headers(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived): if isinstance(event, events.DataReceived):
request_head = self.buf.maybe_extract_lines() request_head = self.buf.maybe_extract_lines()
if request_head: if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head)) self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
yield RequestHeaders(self.stream_id, self.request) yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
if self.request.first_line_format == "authority": if self.request.first_line_format == "authority":
# The previous proxy server implementation tried to read the request body here: # The previous proxy server implementation tried to read the request body here:
@ -180,10 +185,10 @@ class Http1Server(Http1Connection):
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]: def read_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
for e in self.read_body(event, True): for e in self.read_body(event, True):
yield e yield e
if isinstance(e, RequestEndOfMessage): if isinstance(e, ReceiveHttp) and isinstance(e.event, RequestEndOfMessage):
yield from self.mark_done(request=True) yield from self.mark_done(request=True)
@ -192,8 +197,8 @@ class Http1Client(Http1Connection):
send_queue: typing.List[HttpEvent] send_queue: typing.List[HttpEvent]
"""A queue of send events for flows other than the one that is currently being transmitted.""" """A queue of send events for flows other than the one that is currently being transmitted."""
def __init__(self, conn: Server): def __init__(self, context: Context):
super().__init__(conn) super().__init__(context, context.server)
self.state = self.read_response_headers self.state = self.read_response_headers
self.send_queue = [] self.send_queue = []
@ -231,7 +236,7 @@ class Http1Client(Http1Connection):
else: else:
raise NotImplementedError(f"{event}") raise NotImplementedError(f"{event}")
def mark_done(self, *, request: bool = False, response: bool = False): def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
if request: if request:
self.request_done = True self.request_done = True
if response: if response:
@ -246,15 +251,15 @@ class Http1Client(Http1Connection):
for ev in send_queue: for ev in send_queue:
yield from self.send(ev) yield from self.send(ev)
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]: @expect(events.ConnectionEvent)
assert isinstance(event, events.ConnectionEvent) def read_response_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived): if isinstance(event, events.DataReceived):
response_head = self.buf.maybe_extract_lines() response_head = self.buf.maybe_extract_lines()
if response_head: if response_head:
response_head = [bytes(x) for x in response_head] response_head = [bytes(x) for x in response_head]
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head)) self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
yield ResponseHeaders(self.stream_id, self.response) yield ReceiveHttp(ResponseHeaders(self.stream_id, self.response))
expected_size = http1.expected_http_body_size(self.request, self.response) expected_size = http1.expected_http_body_size(self.request, self.response)
self.body_reader = self.make_body_reader(expected_size) self.body_reader = self.make_body_reader(expected_size)
@ -266,23 +271,24 @@ class Http1Client(Http1Connection):
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.ConnectionClosed):
if self.stream_id: if self.stream_id:
if self.buf: if self.buf:
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}") yield ReceiveHttp(
ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}"))
else: else:
# The server has closed the connection to prevent us from continuing. # The server has closed the connection to prevent us from continuing.
# We need to signal that to the stream. # We need to signal that to the stream.
# https://tools.ietf.org/html/rfc7231#section-6.5.11 # https://tools.ietf.org/html/rfc7231#section-6.5.11
yield ResponseProtocolError(self.stream_id, "server closed connection") yield ReceiveHttp(ResponseProtocolError(self.stream_id, "server closed connection"))
else: else:
return return
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
def read_response_body(self, event: events.Event) -> typing.Iterator[HttpEvent]: @expect(events.ConnectionEvent)
assert isinstance(event, events.ConnectionEvent) def read_response_body(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
for e in self.read_body(event, False): for e in self.read_body(event, False):
yield e yield e
if isinstance(e, ResponseEndOfMessage): if isinstance(e, ReceiveHttp) and isinstance(e.event, ResponseEndOfMessage):
self.state = self.read_response_headers self.state = self.read_response_headers
yield from self.mark_done(response=True) yield from self.mark_done(response=True)

View File

@ -67,7 +67,7 @@ class TCPLayer(layer.Layer):
if self.flow: if self.flow:
tcp_message = tcp.TCPMessage(from_client, event.data) tcp_message = tcp.TCPMessage(from_client, event.data)
self.flow.messages.append(tcp_message) self.flow.messages.append(tcp_message)
yield TcpMessageHook(self.flow)t yield TcpMessageHook(self.flow)
yield commands.SendData(send_to, tcp_message.content) yield commands.SendData(send_to, tcp_message.content)
else: else:
yield commands.SendData(send_to, event.data) yield commands.SendData(send_to, event.data)

View File

@ -1,7 +1,7 @@
import struct import struct
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Generator, Iterator, Optional, Tuple from typing import Iterator, Optional, Tuple
from OpenSSL import SSL from OpenSSL import SSL
@ -93,20 +93,6 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9") HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
class EstablishServerTLS(commands.ConnectionCommand):
"""Establish TLS on the given connection.
If TLS establishment fails, the connection will automatically be closed by the TLS layer."""
connection: context.Server
blocking = True
class EstablishServerTLSReply(events.CommandReply):
command: EstablishServerTLS
reply: Optional[str]
"""error message"""
# We need these classes as hooks can only have one argument at the moment. # We need these classes as hooks can only have one argument at the moment.
@dataclass @dataclass
@ -131,62 +117,61 @@ class TlsClienthelloHook(Hook):
class _TLSLayer(layer.Layer): class _TLSLayer(layer.Layer):
tls: Dict[context.Connection, SSL.Connection] conn: context.Connection
"""The connection for which we do TLS"""
tls: SSL.Connection = None
"""The OpenSSL connection object"""
child_layer: layer.Layer child_layer: layer.Layer
ssl_context: Optional[SSL.Context] = None
def __init__(self, context: context.Context): def __init__(self, context: context.Context, conn: context.Connection):
super().__init__(context) super().__init__(context)
self.tls = {} self.conn = conn
def __repr__(self): def __repr__(self):
if not self.tls: if not self.tls:
state = "inactive" state = "inactive"
elif self.conn.tls_established:
state = f"passthrough {self.conn.sni} {self.conn.alpn}"
else: else:
conn_states = [] state = f"negotiating {self.conn.sni} {self.conn.alpn}"
for conn in self.tls:
if conn.tls_established:
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
else:
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
state = ", ".join(conn_states)
return f"{type(self).__name__}({state})" return f"{type(self).__name__}({state})"
def start_tls(self, conn: context.Connection, initial_data: bytes = b""): def start_tls(self, initial_data: bytes = b""):
assert conn not in self.tls assert not self.tls
assert conn.connected assert self.conn.connected
conn.tls = True self.conn.tls = True
tls_start = TlsStartData(conn, self.context) tls_start = TlsStartData(self.conn, self.context)
yield TlsStartHook(tls_start) yield TlsStartHook(tls_start)
self.tls[conn] = tls_start.ssl_conn assert tls_start.ssl_conn
self.tls = tls_start.ssl_conn
yield from self.negotiate(conn, initial_data) yield from self.negotiate(initial_data)
def tls_interact(self, conn: context.Connection) -> layer.CommandGenerator[None]: def tls_interact(self) -> layer.CommandGenerator[None]:
while True: while True:
try: try:
data = self.tls[conn].bio_read(65535) data = self.tls.bio_read(65535)
except SSL.WantReadError: except SSL.WantReadError:
# Okay, nothing more waiting to be sent. # Okay, nothing more waiting to be sent.
return return
else: else:
yield commands.SendData(conn, data) yield commands.SendData(self.conn, data)
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]: def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
# bio_write errors for b"", so we need to check first if we actually received something. # bio_write errors for b"", so we need to check first if we actually received something.
if data: if data:
self.tls[conn].bio_write(data) self.tls.bio_write(data)
try: try:
self.tls[conn].do_handshake() self.tls.do_handshake()
except SSL.WantReadError: except SSL.WantReadError:
yield from self.tls_interact(conn) yield from self.tls_interact()
return False, None return False, None
except SSL.Error as e: except SSL.Error as e:
# provide more detailed information for some errors. # provide more detailed information for some errors.
last_err = e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1] last_err = e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1]
if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'): if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'):
verify_result = SSL._lib.SSL_get_verify_result(self.tls[conn]._ssl) verify_result = SSL._lib.SSL_get_verify_result(self.tls._ssl)
error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode() error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode()
err = f"Certificate verify failed: {error}" err = f"Certificate verify failed: {error}"
elif last_err in [ elif last_err in [
@ -196,40 +181,40 @@ class _TLSLayer(layer.Layer):
err = last_err[2] err = last_err[2]
else: else:
err = repr(e) err = repr(e)
yield from self.on_handshake_error(conn, err) yield from self.on_handshake_error(err)
return False, err return False, err
else: else:
# Get all peer certificates. # Get all peer certificates.
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
# If called on the client side, the stack also contains the peer's certificate; if called on the server # If called on the client side, the stack also contains the peer's certificate; if called on the server
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3). # side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
all_certs = self.tls[conn].get_peer_cert_chain() or [] all_certs = self.tls.get_peer_cert_chain() or []
if conn == self.context.client: if self.conn == self.context.client:
cert = self.tls[conn].get_peer_certificate() cert = self.tls.get_peer_certificate()
if cert: if cert:
all_certs.insert(0, cert) all_certs.insert(0, cert)
conn.tls_established = True self.conn.tls_established = True
conn.sni = self.tls[conn].get_servername() self.conn.sni = self.tls.get_servername()
conn.alpn = self.tls[conn].get_alpn_proto_negotiated() self.conn.alpn = self.tls.get_alpn_proto_negotiated()
conn.certificate_chain = [certs.Cert(x) for x in all_certs] self.conn.certificate_chain = [certs.Cert(x) for x in all_certs]
conn.cipher_list = self.tls[conn].get_cipher_list() self.conn.cipher_list = self.tls.get_cipher_list()
conn.tls_version = self.tls[conn].get_protocol_version_name() self.conn.tls_version = self.tls.get_protocol_version_name()
conn.timestamp_tls_setup = time.time() self.conn.timestamp_tls_setup = time.time()
yield commands.Log(f"TLS established: {conn}") yield commands.Log(f"TLS established: {self.conn}")
yield from self.receive(conn, b"") yield from self.receive(b"")
return True, None return True, None
def receive(self, conn: context.Connection, data: bytes): def receive(self, data: bytes):
if data: if data:
self.tls[conn].bio_write(data) self.tls.bio_write(data)
yield from self.tls_interact(conn) yield from self.tls_interact()
plaintext = bytearray() plaintext = bytearray()
close = False close = False
while True: while True:
try: try:
plaintext.extend(self.tls[conn].recv(65535)) plaintext.extend(self.tls.recv(65535))
except SSL.WantReadError: except SSL.WantReadError:
break break
except SSL.ZeroReturnError: except SSL.ZeroReturnError:
@ -238,76 +223,90 @@ class _TLSLayer(layer.Layer):
if plaintext: if plaintext:
yield from self.event_to_child( yield from self.event_to_child(
events.DataReceived(conn, bytes(plaintext)) events.DataReceived(self.conn, bytes(plaintext))
) )
if close: if close:
conn.state &= ~context.ConnectionState.CAN_READ self.conn.state &= ~context.ConnectionState.CAN_READ
yield commands.Log(f"TLS close_notify {conn=}") yield commands.Log(f"TLS close_notify {self.conn}")
yield from self.event_to_child( yield from self.event_to_child(
events.ConnectionClosed(conn) events.ConnectionClosed(self.conn)
) )
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
for command in self.child_layer.handle_event(event): for command in self.child_layer.handle_event(event):
if isinstance(command, commands.SendData) and command.connection in self.tls: if isinstance(command, commands.SendData) and command.connection == self.conn:
self.tls[command.connection].sendall(command.data) self.tls.sendall(command.data)
yield from self.tls_interact(command.connection) yield from self.tls_interact()
else: else:
yield command yield command
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.connection in self.tls: if isinstance(event, events.DataReceived) and event.connection == self.conn:
if not event.connection.tls_established: if not event.connection.tls_established:
yield from self.negotiate(event.connection, event.data) yield from self.negotiate(event.data)
else: else:
yield from self.receive(event.connection, event.data) yield from self.receive(event.data)
elif isinstance(event, events.ConnectionClosed) and event.connection in self.tls: elif isinstance(event, events.ConnectionClosed) and event.connection == self.conn:
if event.connection.tls_established: if event.connection.tls_established:
if self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN: if self.tls.get_shutdown() & SSL.RECEIVED_SHUTDOWN:
pass # We have already dispatched a ConnectionClosed to the child layer. pass # We have already dispatched a ConnectionClosed to the child layer.
else: else:
yield from self.event_to_child(event) yield from self.event_to_child(event)
else: else:
yield from self.on_handshake_error(event.connection, "connection closed without notice") yield from self.on_handshake_error("connection closed without notice")
else: else:
yield from self.event_to_child(event) yield from self.event_to_child(event)
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]: def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
yield commands.CloseConnection(conn) yield commands.CloseConnection(self.conn)
class ServerTLSLayer(_TLSLayer): class ServerTLSLayer(_TLSLayer):
""" """
This layer manages TLS for potentially multiple server connections. This layer establishes TLS for a single server connection.
""" """
command_to_reply_to: Dict[context.Connection, commands.OpenConnection] command_to_reply_to: Optional[commands.OpenConnection] = None
def __init__(self, context: context.Context): def __init__(self, context: context.Context):
super().__init__(context) super().__init__(context, context.server)
self.command_to_reply_to = {}
self.child_layer = layer.NextLayer(self.context) self.child_layer = layer.NextLayer(self.context)
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]: @expect(events.Start)
done, err = yield from super().negotiate(conn, data) def state_start(self, _) -> layer.CommandGenerator[None]:
self.context.server.tls = True
if self.context.server.connected:
yield from self.start_tls()
self._handle_event = super()._handle_event
_handle_event = state_start
def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
done, err = yield from super().negotiate(data)
if done or err: if done or err:
cmd = self.command_to_reply_to.pop(conn) cmd = self.command_to_reply_to
yield from self.event_to_child(events.OpenConnectionReply(cmd, err)) yield from self.event_to_child(events.OpenConnectionReply(cmd, err))
self.command_to_reply_to = None
return done, err return done, err
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]: def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
for command in super().event_to_child(event): for command in super().event_to_child(event):
if isinstance(command, EstablishServerTLS): if isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
self.command_to_reply_to[command.connection] = command # create our own OpenConnection command object that blocks here.
yield from self.start_tls(command.connection) err = yield commands.OpenConnection(command.connection)
if err:
yield from self.event_to_child(events.OpenConnectionReply(command, err))
else:
self.command_to_reply_to = command
yield from self.start_tls()
else: else:
yield command yield command
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]: def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
yield commands.Log( yield commands.Log(
f"Server TLS handshake failed. {err}", f"Server TLS handshake failed. {err}",
level="warn" level="warn"
) )
yield from super().on_handshake_error(conn, err) yield from super().on_handshake_error(err)
class ClientTLSLayer(_TLSLayer): class ClientTLSLayer(_TLSLayer):
@ -331,10 +330,9 @@ class ClientTLSLayer(_TLSLayer):
def __init__(self, context: context.Context): def __init__(self, context: context.Context):
assert isinstance(context.layers[-1], ServerTLSLayer) assert isinstance(context.layers[-1], ServerTLSLayer)
super().__init__(context) super().__init__(context, self.context.client)
self.recv_buffer = bytearray() self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context) self.child_layer = layer.NextLayer(self.context)
self._handle_event = self.state_start
@expect(events.Start) @expect(events.Start)
def state_start(self, _) -> layer.CommandGenerator[None]: def state_start(self, _) -> layer.CommandGenerator[None]:
@ -342,20 +340,21 @@ class ClientTLSLayer(_TLSLayer):
self._handle_event = self.state_wait_for_clienthello self._handle_event = self.state_wait_for_clienthello
yield from () yield from ()
_handle_event = state_start
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]: def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
client = self.context.client if isinstance(event, events.DataReceived) and event.connection == self.conn:
if isinstance(event, events.DataReceived) and event.connection == client:
self.recv_buffer.extend(event.data) self.recv_buffer.extend(event.data)
try: try:
client_hello = parse_client_hello(self.recv_buffer) client_hello = parse_client_hello(self.recv_buffer)
except ValueError: except ValueError:
yield commands.Log(f"Cannot parse ClientHello: {self.recv_buffer.hex()}") yield commands.Log(f"Cannot parse ClientHello: {self.recv_buffer.hex()}")
yield commands.CloseConnection(client) yield commands.CloseConnection(self.conn)
return return
if client_hello: if client_hello:
client.sni = client_hello.sni self.conn.sni = client_hello.sni
client.alpn_offers = client_hello.alpn_protocols self.conn.alpn_offers = client_hello.alpn_protocols
tls_clienthello = ClientHelloData(self.context) tls_clienthello = ClientHelloData(self.context)
yield TlsClienthelloHook(tls_clienthello) yield TlsClienthelloHook(tls_clienthello)
@ -365,7 +364,7 @@ class ClientTLSLayer(_TLSLayer):
yield commands.Log("Unable to establish TLS connection with server. " yield commands.Log("Unable to establish TLS connection with server. "
"Trying to establish TLS with client anyway.") "Trying to establish TLS with client anyway.")
yield from self.start_tls(client, bytes(self.recv_buffer)) yield from self.start_tls(bytes(self.recv_buffer))
self.recv_buffer.clear() self.recv_buffer.clear()
self._handle_event = super()._handle_event self._handle_event = super()._handle_event
@ -379,25 +378,16 @@ class ClientTLSLayer(_TLSLayer):
We often need information from the upstream connection to establish TLS with the client. We often need information from the upstream connection to establish TLS with the client.
For example, we need to check if the client does ALPN or not. For example, we need to check if the client does ALPN or not.
""" """
server = self.context.server err = yield commands.OpenConnection(self.context.server)
if not server.connected:
err = yield commands.OpenConnection(server)
if err:
yield commands.Log(
f"Cannot establish server connection: {err}"
)
return err
err = yield EstablishServerTLS(server)
if err: if err:
yield commands.Log( yield commands.Log(f"Cannot establish server connection: {err}")
f"Cannot establish TLS with server: {err}"
)
return err return err
else:
return None
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]: def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
if conn.sni: if self.conn.sni:
dest = conn.sni.decode("idna") dest = self.conn.sni.decode("idna")
else: else:
dest = human.format_address(self.context.server.address) dest = human.format_address(self.context.server.address)
if "unknown ca" in err or "bad certificate" in err: if "unknown ca" in err or "bad certificate" in err:
@ -409,4 +399,4 @@ class ClientTLSLayer(_TLSLayer):
f"The client {keyword} trust the proxy's certificate for {dest} ({err})", f"The client {keyword} trust the proxy's certificate for {dest} ({err})",
level="warn" level="warn"
) )
yield from super().on_handshake_error(conn, err) yield from super().on_handshake_error(err)

View File

@ -18,7 +18,7 @@ def expect(*event_types):
if isinstance(event, event_types): if isinstance(event, event_types):
yield from f(self, event) yield from f(self, event)
else: else:
event_types_str = '|'.join(e.__name__ for e in event_types) event_types_str = '|'.join(e.__name__ for e in event_types) or "no events"
raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.") raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.")
return wrapper return wrapper

View File

@ -106,8 +106,8 @@ def test_redirect(strategy, https_server, https_client, tctx):
p << OpenConnection(server) p << OpenConnection(server)
p >> reply(None) p >> reply(None)
if https_server: if https_server:
p << tls.EstablishServerTLS(server) pass # p << tls.EstablishServerTLS(server)
p >> reply_establish_server_tls() # p >> reply_establish_server_tls()
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n") p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!") p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
@ -123,6 +123,7 @@ def test_multiple_server_connections(tctx):
"""Test multiple requests being rewritten to different targets.""" """Test multiple requests being rewritten to different targets."""
server1 = Placeholder() server1 = Placeholder()
server2 = Placeholder() server2 = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
def redirect(to: str): def redirect(to: str):
def side_effect(flow: HTTPFlow): def side_effect(flow: HTTPFlow):
@ -131,7 +132,7 @@ def test_multiple_server_connections(tctx):
return side_effect return side_effect
assert ( assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False) playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHook(Placeholder()) << http.HttpRequestHook(Placeholder())
>> reply(side_effect=redirect("http://one.redirect/")) >> reply(side_effect=redirect("http://one.redirect/"))
@ -140,6 +141,9 @@ def test_multiple_server_connections(tctx):
<< SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n") << SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n")
>> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") >> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
)
assert (
playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHook(Placeholder()) << http.HttpRequestHook(Placeholder())
>> reply(side_effect=redirect("http://two.redirect/")) >> reply(side_effect=redirect("http://two.redirect/"))

View File

@ -4,8 +4,6 @@ import typing
import pytest import pytest
from OpenSSL import SSL from OpenSSL import SSL
import mitmproxy.proxy2.layer
import mitmproxy.proxy2.layers.tls
from mitmproxy.proxy2 import commands, context, events, layer from mitmproxy.proxy2 import commands, context, events, layer
from mitmproxy.proxy2.layers import tls from mitmproxy.proxy2.layers import tls
from mitmproxy.utils import data from mitmproxy.utils import data
@ -110,11 +108,11 @@ class TlsEchoLayer(tutils.EchoLayer):
err: typing.Optional[str] = None err: typing.Optional[str] = None
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls": if isinstance(event, events.DataReceived) and event.data == b"open-connection":
# noinspection PyTypeChecker # noinspection PyTypeChecker
err = yield tls.EstablishServerTLS(self.context.server) err = yield commands.OpenConnection(self.context.server)
if err: if err:
yield commands.SendData(event.connection, f"server-tls-failed: {err}".encode()) yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
else: else:
yield from super()._handle_event(event) yield from super()._handle_event(event)
@ -208,9 +206,6 @@ class TestServerTLS:
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, b"establish-server-tls")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer)
<< tls.TlsStartHook(tutils.Placeholder()) << tls.TlsStartHook(tutils.Placeholder())
>> reply_tls_start() >> reply_tls_start()
<< commands.SendData(tctx.server, data) << commands.SendData(tctx.server, data)
@ -233,6 +228,13 @@ class TestServerTLS:
assert tctx.server.tls_established assert tctx.server.tls_established
# Echo # Echo
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(tutils.Placeholder())
>> tutils.reply_next_layer(TlsEchoLayer)
<< commands.SendData(tctx.client, b"foo")
)
_test_echo(playbook, tssl, tctx.server) _test_echo(playbook, tssl, tctx.server)
with pytest.raises(ssl.SSLWantReadError): with pytest.raises(ssl.SSLWantReadError):
@ -274,7 +276,7 @@ class TestServerTLS:
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.server, tssl.out.read()) >> events.DataReceived(tctx.server, tssl.out.read())
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch","warn") << commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
<< commands.CloseConnection(tctx.server) << commands.CloseConnection(tctx.server)
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch") << commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
) )

View File

@ -170,7 +170,11 @@ class Playbook:
pass pass
else: else:
if hasattr(x, "playbook_eval"): if hasattr(x, "playbook_eval"):
x = self.expected[i] = x.playbook_eval(self) try:
x = self.expected[i] = x.playbook_eval(self)
except Exception:
self.actual.append(_TracebackInPlaybook(traceback.format_exc()))
break
for name, value in vars(x).items(): for name, value in vars(x).items():
if isinstance(value, _Placeholder): if isinstance(value, _Placeholder):
setattr(x, name, value()) setattr(x, name, value())
@ -273,8 +277,7 @@ class reply(events.Event):
self.to = cmd self.to = cmd
break break
else: else:
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual) raise AssertionError(f"Expected command {self.to} did not occur.")
raise AssertionError(f"Expected command {self.to} did not occur:\n{actual_str}")
assert isinstance(self.to, commands.Command) assert isinstance(self.to, commands.Command)
if isinstance(self.to, commands.Hook): if isinstance(self.to, commands.Hook):