mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] wip: tls establishment semantics
This commit is contained in:
parent
7efe27be74
commit
b2060356b6
@ -74,16 +74,6 @@ class CloseConnection(ConnectionCommand):
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolError(Command):
|
||||
"""
|
||||
Indicate that an unrecoverable protocol error has occured.
|
||||
"""
|
||||
message: str
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
|
||||
class Hook(Command):
|
||||
"""
|
||||
Callback to the master (like ".ask()")
|
||||
@ -143,5 +133,3 @@ class Log(Command):
|
||||
|
||||
def __repr__(self):
|
||||
return f"Log({self.message}, {self.level})"
|
||||
|
||||
|
||||
|
@ -5,10 +5,10 @@ from mitmproxy import flow, http
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Connection, Context, Server
|
||||
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
from ._base import HttpConnection, StreamId
|
||||
from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
||||
@ -17,10 +17,6 @@ from ._http1 import Http1Client, Http1Server
|
||||
from ._http2 import Http2Client
|
||||
|
||||
|
||||
class HttpCommand(commands.Command):
|
||||
pass
|
||||
|
||||
|
||||
class GetHttpConnection(HttpCommand):
|
||||
"""
|
||||
Open a HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead.
|
||||
@ -47,6 +43,18 @@ class GetHttpConnectionReply(events.CommandReply):
|
||||
"""connection object, error message"""
|
||||
|
||||
|
||||
class RegisterHttpConnection(HttpCommand):
|
||||
"""
|
||||
Register that a HTTP connection has been successfully established.
|
||||
"""
|
||||
connection: Connection
|
||||
err: str
|
||||
|
||||
def __init__(self, connection: Connection, err: str):
|
||||
self.connection = connection
|
||||
self.err = err
|
||||
|
||||
|
||||
class SendHttp(HttpCommand):
|
||||
connection: Connection
|
||||
event: HttpEvent
|
||||
@ -59,6 +67,8 @@ class SendHttp(HttpCommand):
|
||||
return f"Send({self.event})"
|
||||
|
||||
|
||||
|
||||
|
||||
class HttpStream(layer.Layer):
|
||||
request_body_buf: bytes
|
||||
response_body_buf: bytes
|
||||
@ -299,7 +309,7 @@ class HttpStream(layer.Layer):
|
||||
yield from ()
|
||||
|
||||
|
||||
class HTTPLayer(Layer):
|
||||
class HTTPLayer(layer.Layer):
|
||||
"""
|
||||
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
|
||||
HttpEvent: We have received request headers
|
||||
@ -311,47 +321,34 @@ class HTTPLayer(Layer):
|
||||
mode: HTTPMode
|
||||
stream_by_command: typing.Dict[commands.Command, HttpStream]
|
||||
streams: typing.Dict[int, HttpStream]
|
||||
connections: typing.Dict[Connection, typing.Union[HttpConnection, HttpStream]]
|
||||
waiting_for_connection: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
|
||||
event_queue: typing.Deque[
|
||||
typing.Union[HttpEvent, HttpCommand, commands.Command]
|
||||
]
|
||||
connections: typing.Dict[Connection, typing.Union[layer.Layer, HttpStream]]
|
||||
waiting_for_establishment: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
|
||||
command_queue: typing.Deque[commands.Command]
|
||||
|
||||
def __init__(self, context: Context, mode: HTTPMode):
|
||||
super().__init__(context)
|
||||
self.mode = mode
|
||||
|
||||
self.waiting_for_connection = collections.defaultdict(list)
|
||||
self.waiting_for_establishment = collections.defaultdict(list)
|
||||
self.streams = {}
|
||||
self.stream_by_command = {}
|
||||
self.event_queue = collections.deque()
|
||||
self.command_queue = collections.deque()
|
||||
|
||||
self.connections = {
|
||||
context.client: Http1Server(context.client)
|
||||
context.client: Http1Server(context.fork())
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
|
||||
return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
|
||||
|
||||
def _handle_event(self, event: events.Event):
|
||||
if isinstance(event, events.Start):
|
||||
return
|
||||
elif isinstance(event, (EstablishServerTLSReply, events.OpenConnectionReply)) and \
|
||||
event.command.connection in self.waiting_for_connection:
|
||||
if event.reply:
|
||||
waiting = self.waiting_for_connection.pop(event.command.connection)
|
||||
for cmd in waiting:
|
||||
stream = self.stream_by_command.pop(cmd)
|
||||
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
|
||||
else:
|
||||
yield from self.make_http_connection(event.command.connection)
|
||||
elif isinstance(event, events.CommandReply):
|
||||
try:
|
||||
stream = self.stream_by_command.pop(event.command)
|
||||
except KeyError:
|
||||
raise
|
||||
if isinstance(event, events.OpenConnectionReply):
|
||||
self.connections[event.command.connection] = stream
|
||||
self.event_to_child(stream, event)
|
||||
elif isinstance(event, events.ConnectionEvent):
|
||||
if event.connection == self.context.server and self.context.server not in self.connections:
|
||||
@ -362,117 +359,129 @@ class HTTPLayer(Layer):
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
||||
while self.event_queue:
|
||||
event = self.event_queue.popleft()
|
||||
if isinstance(event, RequestHeaders):
|
||||
self.streams[event.stream_id] = self.make_stream()
|
||||
if isinstance(event, HttpEvent):
|
||||
stream = self.streams[event.stream_id]
|
||||
self.event_to_child(stream, event)
|
||||
elif isinstance(event, SendHttp):
|
||||
conn = self.connections[event.connection]
|
||||
evts = conn.send(event.event)
|
||||
self.event_queue.extend(evts)
|
||||
elif isinstance(event, GetHttpConnection):
|
||||
yield from self.get_connection(event)
|
||||
elif isinstance(event, commands.Command):
|
||||
yield event
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
while self.command_queue:
|
||||
command = self.command_queue.popleft()
|
||||
if isinstance(command, ReceiveHttp):
|
||||
if isinstance(command.event, RequestHeaders):
|
||||
self.streams[command.event.stream_id] = self.make_stream()
|
||||
stream = self.streams[command.event.stream_id]
|
||||
self.event_to_child(stream, command.event)
|
||||
elif isinstance(command, SendHttp):
|
||||
conn = self.connections[command.connection]
|
||||
self.event_to_child(conn, command.event)
|
||||
elif isinstance(command, GetHttpConnection):
|
||||
self.get_connection(command)
|
||||
elif isinstance(command, RegisterHttpConnection):
|
||||
yield from self.register_connection(command)
|
||||
elif isinstance(command, commands.Command):
|
||||
yield command
|
||||
else: # pragma: no cover
|
||||
raise ValueError(f"Not a command command: {command}")
|
||||
|
||||
def make_stream(self) -> HttpStream:
|
||||
ctx = self.context.fork()
|
||||
|
||||
stream = HttpStream(ctx)
|
||||
if self.debug:
|
||||
stream.debug = self.debug + " "
|
||||
self.event_to_child(stream, events.Start())
|
||||
return stream
|
||||
|
||||
def get_connection(self, event: GetHttpConnection):
|
||||
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
|
||||
# Do we already have a connection we can re-use?
|
||||
for connection, handler in self.connections.items():
|
||||
for connection, layer in self.connections.items():
|
||||
connection_suitable = (
|
||||
reuse and
|
||||
event.connection_spec_matches(connection) and
|
||||
(
|
||||
isinstance(handler, Http2Client) or
|
||||
# see "tricky multiplexing edge case" in make_http_connection for an explanation
|
||||
isinstance(handler, Http1Client) and self.context.client.alpn != b"h2"
|
||||
connection.alpn == b"h2" or self.context.client.alpn != b"h2"
|
||||
)
|
||||
)
|
||||
if connection_suitable:
|
||||
if connection in self.waiting_for_establishment:
|
||||
self.waiting_for_establishment[connection].append(event)
|
||||
else:
|
||||
stream = self.stream_by_command.pop(event)
|
||||
self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None)))
|
||||
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None)))
|
||||
return
|
||||
# Are we waiting for one?
|
||||
for connection in self.waiting_for_connection:
|
||||
if event.connection_spec_matches(connection):
|
||||
self.waiting_for_connection[connection].append(event)
|
||||
return
|
||||
# Can we reuse context.server?
|
||||
|
||||
can_reuse_context_connection = (
|
||||
self.context.server not in self.connections and
|
||||
self.context.server.connected and
|
||||
self.context.server.address == event.address and
|
||||
self.context.server.tls == event.tls
|
||||
event.connection_spec_matches(self.context.server)
|
||||
)
|
||||
if can_reuse_context_connection:
|
||||
self.waiting_for_connection[self.context.server].append(event)
|
||||
yield from self.make_http_connection(self.context.server)
|
||||
# We need a new one.
|
||||
context = self.context.fork()
|
||||
layer = HttpClient(context)
|
||||
if not can_reuse_context_connection:
|
||||
context.server = Server(event.address)
|
||||
if event.tls:
|
||||
context.server.tls = True
|
||||
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
|
||||
orig = layer
|
||||
layer = tls.ServerTLSLayer(context)
|
||||
layer.child_layer = orig
|
||||
# TODO: Here we should create other sublayer(s) for upstream proxy etc.
|
||||
|
||||
self.connections[context.server] = layer
|
||||
self.waiting_for_establishment[context.server].append(event)
|
||||
|
||||
self.event_to_child(layer, events.Start())
|
||||
|
||||
def register_connection(self, command: RegisterHttpConnection):
|
||||
waiting = self.waiting_for_establishment.pop(command.connection)
|
||||
|
||||
if command.err:
|
||||
reply = (None, command.err)
|
||||
self.connections.pop(command.connection)
|
||||
else:
|
||||
connection = Server(event.address)
|
||||
connection.tls = event.tls
|
||||
self.waiting_for_connection[connection].append(event)
|
||||
open_command = commands.OpenConnection(connection)
|
||||
open_command.blocking = object()
|
||||
yield open_command
|
||||
reply = (command.connection, None)
|
||||
|
||||
def make_http_connection(self, connection: Server) -> None:
|
||||
if connection.tls and not connection.tls_established:
|
||||
connection.alpn_offers = list(HTTP_ALPNS)
|
||||
if not self.context.options.http2:
|
||||
connection.alpn_offers.remove(b"h2")
|
||||
new_command = EstablishServerTLS(connection)
|
||||
new_command.blocking = object()
|
||||
yield new_command
|
||||
return
|
||||
|
||||
if connection.alpn == b"h2":
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.connections[connection] = Http1Client(connection)
|
||||
|
||||
waiting = self.waiting_for_connection.pop(connection)
|
||||
for cmd in waiting:
|
||||
stream = self.stream_by_command.pop(cmd)
|
||||
self.event_to_child(stream, GetHttpConnectionReply(cmd, (connection, None)))
|
||||
self.event_to_child(stream, GetHttpConnectionReply(cmd, reply))
|
||||
|
||||
# Tricky multiplexing edge case: Assume a h2 client that sends two requests (or receives two responses)
|
||||
# that neither have a content-length specified nor a chunked transfer encoding.
|
||||
# We can't process these two flows to the same h1 connection as they would both have
|
||||
# "read until eof" semantics. We could force chunked transfer encoding for requests, but can't enforce that
|
||||
# for responses. The only workaround left is to open a separate connection for each flow.
|
||||
if self.context.client.alpn == b"h2" and connection.alpn != b"h2":
|
||||
if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
|
||||
for cmd in waiting[1:]:
|
||||
new_connection = Server(connection.address)
|
||||
new_connection.tls = connection.tls
|
||||
self.waiting_for_connection[new_connection].append(cmd)
|
||||
open_command = commands.OpenConnection(new_connection)
|
||||
open_command.blocking = object()
|
||||
yield open_command
|
||||
yield from self.get_connection(cmd, reuse=False)
|
||||
break
|
||||
|
||||
def event_to_child(
|
||||
self,
|
||||
stream: typing.Union[HttpConnection, HttpStream],
|
||||
child: typing.Union[layer.Layer, HttpStream],
|
||||
event: events.Event,
|
||||
) -> None:
|
||||
stream_events = list(stream.handle_event(event))
|
||||
for se in stream_events:
|
||||
child_commands = list(child.handle_event(event))
|
||||
for cmd in child_commands:
|
||||
assert isinstance(cmd, commands.Command)
|
||||
# Streams may yield blocking commands, which ultimately generate CommandReply events.
|
||||
# Those need to be routed back to the correct stream, so we need to keep track of that.
|
||||
if isinstance(se, commands.Command) and se.blocking:
|
||||
self.stream_by_command[se] = stream
|
||||
if isinstance(cmd, commands.OpenConnection):
|
||||
self.connections[cmd.connection] = child
|
||||
|
||||
self.event_queue.extend(stream_events)
|
||||
if cmd.blocking:
|
||||
self.stream_by_command[cmd] = child
|
||||
|
||||
self.command_queue.extend(child_commands)
|
||||
|
||||
|
||||
class HttpClient(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if self.context.server.connected:
|
||||
err = None
|
||||
else:
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
yield RegisterHttpConnection(self.context.server, err)
|
||||
if err:
|
||||
return
|
||||
|
||||
if self.context.server.alpn == b"h2":
|
||||
raise NotImplementedError
|
||||
else:
|
||||
child_layer = Http1Client(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
|
@ -1,8 +1,6 @@
|
||||
import abc
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy.proxy2 import events, layer
|
||||
from mitmproxy.proxy2 import events, layer, commands
|
||||
|
||||
StreamId = int
|
||||
|
||||
@ -18,14 +16,22 @@ class HttpEvent(events.Event):
|
||||
return f"{type(self).__name__}({repr(x) if x else ''})"
|
||||
|
||||
|
||||
class HttpConnection(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
yield from ()
|
||||
class HttpConnection(layer.Layer):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
class HttpCommand(commands.Command):
|
||||
pass
|
||||
|
||||
|
||||
class ReceiveHttp(HttpCommand):
|
||||
event: HttpEvent
|
||||
|
||||
def __init__(self, event: HttpEvent):
|
||||
self.event = event
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Receive({self.event})"
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -1,5 +1,5 @@
|
||||
import abc
|
||||
import typing
|
||||
from abc import abstractmethod
|
||||
|
||||
import h11
|
||||
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
|
||||
@ -9,8 +9,9 @@ from mitmproxy import http
|
||||
from mitmproxy.net.http import http1
|
||||
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Client, Connection, Server
|
||||
from mitmproxy.proxy2.layers.http._base import StreamId
|
||||
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
||||
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from ._base import HttpConnection
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
@ -18,28 +19,32 @@ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders
|
||||
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
||||
|
||||
|
||||
class Http1Connection(HttpConnection):
|
||||
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
conn: Connection
|
||||
stream_id: typing.Optional[StreamId] = None
|
||||
request: typing.Optional[http.HTTPRequest] = None
|
||||
response: typing.Optional[http.HTTPResponse] = None
|
||||
request_done: bool = False
|
||||
response_done: bool = False
|
||||
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
|
||||
state: typing.Callable[[events.Event], layer.CommandGenerator[None]]
|
||||
body_reader: TBodyReader
|
||||
buf: ReceiveBuffer
|
||||
|
||||
def __init__(self, conn: Connection):
|
||||
def __init__(self, context: Context, conn: Connection):
|
||||
super().__init__(context)
|
||||
assert isinstance(conn, Connection)
|
||||
self.conn = conn
|
||||
self.buf = ReceiveBuffer()
|
||||
|
||||
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, HttpEvent):
|
||||
yield from self.send(event)
|
||||
else:
|
||||
if isinstance(event, events.DataReceived):
|
||||
self.buf += event.data
|
||||
yield from self.state(event)
|
||||
|
||||
@abstractmethod
|
||||
@abc.abstractmethod
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
@ -51,7 +56,7 @@ class Http1Connection(HttpConnection):
|
||||
else:
|
||||
return ContentLengthReader(expected_size)
|
||||
|
||||
def read_body(self, event: events.Event, is_request: bool) -> typing.Iterator[HttpEvent]:
|
||||
def read_body(self, event: events.Event, is_request: bool) -> layer.CommandGenerator[None]:
|
||||
while True:
|
||||
try:
|
||||
if isinstance(event, events.DataReceived):
|
||||
@ -63,9 +68,9 @@ class Http1Connection(HttpConnection):
|
||||
except h11.ProtocolError as e:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
if is_request:
|
||||
yield RequestProtocolError(self.stream_id, str(e))
|
||||
yield ReceiveHttp(RequestProtocolError(self.stream_id, str(e)))
|
||||
else:
|
||||
yield ResponseProtocolError(self.stream_id, str(e))
|
||||
yield ReceiveHttp(ResponseProtocolError(self.stream_id, str(e)))
|
||||
return
|
||||
|
||||
if h11_event is None:
|
||||
@ -73,17 +78,17 @@ class Http1Connection(HttpConnection):
|
||||
elif isinstance(h11_event, h11.Data):
|
||||
h11_event.data: bytearray # type checking
|
||||
if is_request:
|
||||
yield RequestData(self.stream_id, bytes(h11_event.data))
|
||||
yield ReceiveHttp(RequestData(self.stream_id, bytes(h11_event.data)))
|
||||
else:
|
||||
yield ResponseData(self.stream_id, bytes(h11_event.data))
|
||||
yield ReceiveHttp(ResponseData(self.stream_id, bytes(h11_event.data)))
|
||||
elif isinstance(h11_event, h11.EndOfMessage):
|
||||
if is_request:
|
||||
yield RequestEndOfMessage(self.stream_id)
|
||||
yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
|
||||
else:
|
||||
yield ResponseEndOfMessage(self.stream_id)
|
||||
yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
|
||||
return
|
||||
|
||||
def wait(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
"""
|
||||
We wait for the current flow to be finished before parsing the next message,
|
||||
as we may want to upgrade to WebSocket or plain TCP before that.
|
||||
@ -101,8 +106,8 @@ class Http1Server(Http1Connection):
|
||||
"""A simple HTTP/1 server with no pipelining support."""
|
||||
conn: Client
|
||||
|
||||
def __init__(self, conn: Client):
|
||||
super().__init__(conn)
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.client)
|
||||
self.stream_id = 1
|
||||
self.state = self.read_request_headers
|
||||
|
||||
@ -138,7 +143,7 @@ class Http1Server(Http1Connection):
|
||||
else:
|
||||
raise NotImplementedError(f"{event}")
|
||||
|
||||
def mark_done(self, *, request: bool = False, response: bool = False):
|
||||
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
|
||||
if request:
|
||||
self.request_done = True
|
||||
if response:
|
||||
@ -152,13 +157,13 @@ class Http1Server(Http1Connection):
|
||||
elif self.request_done:
|
||||
self.state = self.wait
|
||||
|
||||
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
def read_request_headers(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
request_head = self.buf.maybe_extract_lines()
|
||||
if request_head:
|
||||
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
||||
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
||||
yield RequestHeaders(self.stream_id, self.request)
|
||||
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
|
||||
|
||||
if self.request.first_line_format == "authority":
|
||||
# The previous proxy server implementation tried to read the request body here:
|
||||
@ -180,10 +185,10 @@ class Http1Server(Http1Connection):
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
||||
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
def read_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
for e in self.read_body(event, True):
|
||||
yield e
|
||||
if isinstance(e, RequestEndOfMessage):
|
||||
if isinstance(e, ReceiveHttp) and isinstance(e.event, RequestEndOfMessage):
|
||||
yield from self.mark_done(request=True)
|
||||
|
||||
|
||||
@ -192,8 +197,8 @@ class Http1Client(Http1Connection):
|
||||
send_queue: typing.List[HttpEvent]
|
||||
"""A queue of send events for flows other than the one that is currently being transmitted."""
|
||||
|
||||
def __init__(self, conn: Server):
|
||||
super().__init__(conn)
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.server)
|
||||
self.state = self.read_response_headers
|
||||
self.send_queue = []
|
||||
|
||||
@ -231,7 +236,7 @@ class Http1Client(Http1Connection):
|
||||
else:
|
||||
raise NotImplementedError(f"{event}")
|
||||
|
||||
def mark_done(self, *, request: bool = False, response: bool = False):
|
||||
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
|
||||
if request:
|
||||
self.request_done = True
|
||||
if response:
|
||||
@ -246,15 +251,15 @@ class Http1Client(Http1Connection):
|
||||
for ev in send_queue:
|
||||
yield from self.send(ev)
|
||||
|
||||
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
assert isinstance(event, events.ConnectionEvent)
|
||||
@expect(events.ConnectionEvent)
|
||||
def read_response_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
response_head = self.buf.maybe_extract_lines()
|
||||
|
||||
if response_head:
|
||||
response_head = [bytes(x) for x in response_head]
|
||||
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
|
||||
yield ResponseHeaders(self.stream_id, self.response)
|
||||
yield ReceiveHttp(ResponseHeaders(self.stream_id, self.response))
|
||||
|
||||
expected_size = http1.expected_http_body_size(self.request, self.response)
|
||||
self.body_reader = self.make_body_reader(expected_size)
|
||||
@ -266,23 +271,24 @@ class Http1Client(Http1Connection):
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
if self.stream_id:
|
||||
if self.buf:
|
||||
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}")
|
||||
yield ReceiveHttp(
|
||||
ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}"))
|
||||
else:
|
||||
# The server has closed the connection to prevent us from continuing.
|
||||
# We need to signal that to the stream.
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
||||
yield ResponseProtocolError(self.stream_id, "server closed connection")
|
||||
yield ReceiveHttp(ResponseProtocolError(self.stream_id, "server closed connection"))
|
||||
else:
|
||||
return
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
||||
def read_response_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
assert isinstance(event, events.ConnectionEvent)
|
||||
@expect(events.ConnectionEvent)
|
||||
def read_response_body(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
for e in self.read_body(event, False):
|
||||
yield e
|
||||
if isinstance(e, ResponseEndOfMessage):
|
||||
if isinstance(e, ReceiveHttp) and isinstance(e.event, ResponseEndOfMessage):
|
||||
self.state = self.read_response_headers
|
||||
yield from self.mark_done(response=True)
|
||||
|
||||
|
@ -67,7 +67,7 @@ class TCPLayer(layer.Layer):
|
||||
if self.flow:
|
||||
tcp_message = tcp.TCPMessage(from_client, event.data)
|
||||
self.flow.messages.append(tcp_message)
|
||||
yield TcpMessageHook(self.flow)t
|
||||
yield TcpMessageHook(self.flow)
|
||||
yield commands.SendData(send_to, tcp_message.content)
|
||||
else:
|
||||
yield commands.SendData(send_to, event.data)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import struct
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
||||
from typing import Iterator, Optional, Tuple
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
@ -93,20 +93,6 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
||||
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
|
||||
|
||||
|
||||
class EstablishServerTLS(commands.ConnectionCommand):
|
||||
"""Establish TLS on the given connection.
|
||||
|
||||
If TLS establishment fails, the connection will automatically be closed by the TLS layer."""
|
||||
connection: context.Server
|
||||
blocking = True
|
||||
|
||||
|
||||
class EstablishServerTLSReply(events.CommandReply):
|
||||
command: EstablishServerTLS
|
||||
reply: Optional[str]
|
||||
"""error message"""
|
||||
|
||||
|
||||
# We need these classes as hooks can only have one argument at the moment.
|
||||
|
||||
@dataclass
|
||||
@ -131,62 +117,61 @@ class TlsClienthelloHook(Hook):
|
||||
|
||||
|
||||
class _TLSLayer(layer.Layer):
|
||||
tls: Dict[context.Connection, SSL.Connection]
|
||||
conn: context.Connection
|
||||
"""The connection for which we do TLS"""
|
||||
tls: SSL.Connection = None
|
||||
"""The OpenSSL connection object"""
|
||||
child_layer: layer.Layer
|
||||
ssl_context: Optional[SSL.Context] = None
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
def __init__(self, context: context.Context, conn: context.Connection):
|
||||
super().__init__(context)
|
||||
self.tls = {}
|
||||
self.conn = conn
|
||||
|
||||
def __repr__(self):
|
||||
if not self.tls:
|
||||
state = "inactive"
|
||||
elif self.conn.tls_established:
|
||||
state = f"passthrough {self.conn.sni} {self.conn.alpn}"
|
||||
else:
|
||||
conn_states = []
|
||||
for conn in self.tls:
|
||||
if conn.tls_established:
|
||||
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
|
||||
else:
|
||||
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
|
||||
state = ", ".join(conn_states)
|
||||
state = f"negotiating {self.conn.sni} {self.conn.alpn}"
|
||||
return f"{type(self).__name__}({state})"
|
||||
|
||||
def start_tls(self, conn: context.Connection, initial_data: bytes = b""):
|
||||
assert conn not in self.tls
|
||||
assert conn.connected
|
||||
conn.tls = True
|
||||
def start_tls(self, initial_data: bytes = b""):
|
||||
assert not self.tls
|
||||
assert self.conn.connected
|
||||
self.conn.tls = True
|
||||
|
||||
tls_start = TlsStartData(conn, self.context)
|
||||
tls_start = TlsStartData(self.conn, self.context)
|
||||
yield TlsStartHook(tls_start)
|
||||
self.tls[conn] = tls_start.ssl_conn
|
||||
assert tls_start.ssl_conn
|
||||
self.tls = tls_start.ssl_conn
|
||||
|
||||
yield from self.negotiate(conn, initial_data)
|
||||
yield from self.negotiate(initial_data)
|
||||
|
||||
def tls_interact(self, conn: context.Connection) -> layer.CommandGenerator[None]:
|
||||
def tls_interact(self) -> layer.CommandGenerator[None]:
|
||||
while True:
|
||||
try:
|
||||
data = self.tls[conn].bio_read(65535)
|
||||
data = self.tls.bio_read(65535)
|
||||
except SSL.WantReadError:
|
||||
# Okay, nothing more waiting to be sent.
|
||||
return
|
||||
else:
|
||||
yield commands.SendData(conn, data)
|
||||
yield commands.SendData(self.conn, data)
|
||||
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||
if data:
|
||||
self.tls[conn].bio_write(data)
|
||||
self.tls.bio_write(data)
|
||||
try:
|
||||
self.tls[conn].do_handshake()
|
||||
self.tls.do_handshake()
|
||||
except SSL.WantReadError:
|
||||
yield from self.tls_interact(conn)
|
||||
yield from self.tls_interact()
|
||||
return False, None
|
||||
except SSL.Error as e:
|
||||
# provide more detailed information for some errors.
|
||||
last_err = e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1]
|
||||
if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'):
|
||||
verify_result = SSL._lib.SSL_get_verify_result(self.tls[conn]._ssl)
|
||||
verify_result = SSL._lib.SSL_get_verify_result(self.tls._ssl)
|
||||
error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode()
|
||||
err = f"Certificate verify failed: {error}"
|
||||
elif last_err in [
|
||||
@ -196,40 +181,40 @@ class _TLSLayer(layer.Layer):
|
||||
err = last_err[2]
|
||||
else:
|
||||
err = repr(e)
|
||||
yield from self.on_handshake_error(conn, err)
|
||||
yield from self.on_handshake_error(err)
|
||||
return False, err
|
||||
else:
|
||||
# Get all peer certificates.
|
||||
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
||||
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
||||
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
|
||||
all_certs = self.tls[conn].get_peer_cert_chain() or []
|
||||
if conn == self.context.client:
|
||||
cert = self.tls[conn].get_peer_certificate()
|
||||
all_certs = self.tls.get_peer_cert_chain() or []
|
||||
if self.conn == self.context.client:
|
||||
cert = self.tls.get_peer_certificate()
|
||||
if cert:
|
||||
all_certs.insert(0, cert)
|
||||
|
||||
conn.tls_established = True
|
||||
conn.sni = self.tls[conn].get_servername()
|
||||
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
||||
conn.certificate_chain = [certs.Cert(x) for x in all_certs]
|
||||
conn.cipher_list = self.tls[conn].get_cipher_list()
|
||||
conn.tls_version = self.tls[conn].get_protocol_version_name()
|
||||
conn.timestamp_tls_setup = time.time()
|
||||
yield commands.Log(f"TLS established: {conn}")
|
||||
yield from self.receive(conn, b"")
|
||||
self.conn.tls_established = True
|
||||
self.conn.sni = self.tls.get_servername()
|
||||
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
|
||||
self.conn.certificate_chain = [certs.Cert(x) for x in all_certs]
|
||||
self.conn.cipher_list = self.tls.get_cipher_list()
|
||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||
self.conn.timestamp_tls_setup = time.time()
|
||||
yield commands.Log(f"TLS established: {self.conn}")
|
||||
yield from self.receive(b"")
|
||||
return True, None
|
||||
|
||||
def receive(self, conn: context.Connection, data: bytes):
|
||||
def receive(self, data: bytes):
|
||||
if data:
|
||||
self.tls[conn].bio_write(data)
|
||||
yield from self.tls_interact(conn)
|
||||
self.tls.bio_write(data)
|
||||
yield from self.tls_interact()
|
||||
|
||||
plaintext = bytearray()
|
||||
close = False
|
||||
while True:
|
||||
try:
|
||||
plaintext.extend(self.tls[conn].recv(65535))
|
||||
plaintext.extend(self.tls.recv(65535))
|
||||
except SSL.WantReadError:
|
||||
break
|
||||
except SSL.ZeroReturnError:
|
||||
@ -238,76 +223,90 @@ class _TLSLayer(layer.Layer):
|
||||
|
||||
if plaintext:
|
||||
yield from self.event_to_child(
|
||||
events.DataReceived(conn, bytes(plaintext))
|
||||
events.DataReceived(self.conn, bytes(plaintext))
|
||||
)
|
||||
if close:
|
||||
conn.state &= ~context.ConnectionState.CAN_READ
|
||||
yield commands.Log(f"TLS close_notify {conn=}")
|
||||
self.conn.state &= ~context.ConnectionState.CAN_READ
|
||||
yield commands.Log(f"TLS close_notify {self.conn}")
|
||||
yield from self.event_to_child(
|
||||
events.ConnectionClosed(conn)
|
||||
events.ConnectionClosed(self.conn)
|
||||
)
|
||||
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
for command in self.child_layer.handle_event(event):
|
||||
if isinstance(command, commands.SendData) and command.connection in self.tls:
|
||||
self.tls[command.connection].sendall(command.data)
|
||||
yield from self.tls_interact(command.connection)
|
||||
if isinstance(command, commands.SendData) and command.connection == self.conn:
|
||||
self.tls.sendall(command.data)
|
||||
yield from self.tls_interact()
|
||||
else:
|
||||
yield command
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived) and event.connection in self.tls:
|
||||
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
||||
if not event.connection.tls_established:
|
||||
yield from self.negotiate(event.connection, event.data)
|
||||
yield from self.negotiate(event.data)
|
||||
else:
|
||||
yield from self.receive(event.connection, event.data)
|
||||
elif isinstance(event, events.ConnectionClosed) and event.connection in self.tls:
|
||||
yield from self.receive(event.data)
|
||||
elif isinstance(event, events.ConnectionClosed) and event.connection == self.conn:
|
||||
if event.connection.tls_established:
|
||||
if self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN:
|
||||
if self.tls.get_shutdown() & SSL.RECEIVED_SHUTDOWN:
|
||||
pass # We have already dispatched a ConnectionClosed to the child layer.
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
else:
|
||||
yield from self.on_handshake_error(event.connection, "connection closed without notice")
|
||||
yield from self.on_handshake_error("connection closed without notice")
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.CloseConnection(conn)
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
|
||||
|
||||
class ServerTLSLayer(_TLSLayer):
|
||||
"""
|
||||
This layer manages TLS for potentially multiple server connections.
|
||||
This layer establishes TLS for a single server connection.
|
||||
"""
|
||||
command_to_reply_to: Dict[context.Connection, commands.OpenConnection]
|
||||
command_to_reply_to: Optional[commands.OpenConnection] = None
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context)
|
||||
self.command_to_reply_to = {}
|
||||
super().__init__(context, context.server)
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
done, err = yield from super().negotiate(conn, data)
|
||||
@expect(events.Start)
|
||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||
self.context.server.tls = True
|
||||
if self.context.server.connected:
|
||||
yield from self.start_tls()
|
||||
self._handle_event = super()._handle_event
|
||||
|
||||
_handle_event = state_start
|
||||
|
||||
def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
done, err = yield from super().negotiate(data)
|
||||
if done or err:
|
||||
cmd = self.command_to_reply_to.pop(conn)
|
||||
cmd = self.command_to_reply_to
|
||||
yield from self.event_to_child(events.OpenConnectionReply(cmd, err))
|
||||
self.command_to_reply_to = None
|
||||
return done, err
|
||||
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
for command in super().event_to_child(event):
|
||||
if isinstance(command, EstablishServerTLS):
|
||||
self.command_to_reply_to[command.connection] = command
|
||||
yield from self.start_tls(command.connection)
|
||||
if isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
|
||||
# create our own OpenConnection command object that blocks here.
|
||||
err = yield commands.OpenConnection(command.connection)
|
||||
if err:
|
||||
yield from self.event_to_child(events.OpenConnectionReply(command, err))
|
||||
else:
|
||||
self.command_to_reply_to = command
|
||||
yield from self.start_tls()
|
||||
else:
|
||||
yield command
|
||||
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.Log(
|
||||
f"Server TLS handshake failed. {err}",
|
||||
level="warn"
|
||||
)
|
||||
yield from super().on_handshake_error(conn, err)
|
||||
yield from super().on_handshake_error(err)
|
||||
|
||||
|
||||
class ClientTLSLayer(_TLSLayer):
|
||||
@ -331,10 +330,9 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
assert isinstance(context.layers[-1], ServerTLSLayer)
|
||||
super().__init__(context)
|
||||
super().__init__(context, self.context.client)
|
||||
self.recv_buffer = bytearray()
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = self.state_start
|
||||
|
||||
@expect(events.Start)
|
||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||
@ -342,20 +340,21 @@ class ClientTLSLayer(_TLSLayer):
|
||||
self._handle_event = self.state_wait_for_clienthello
|
||||
yield from ()
|
||||
|
||||
_handle_event = state_start
|
||||
|
||||
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
client = self.context.client
|
||||
if isinstance(event, events.DataReceived) and event.connection == client:
|
||||
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
||||
self.recv_buffer.extend(event.data)
|
||||
try:
|
||||
client_hello = parse_client_hello(self.recv_buffer)
|
||||
except ValueError:
|
||||
yield commands.Log(f"Cannot parse ClientHello: {self.recv_buffer.hex()}")
|
||||
yield commands.CloseConnection(client)
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
|
||||
if client_hello:
|
||||
client.sni = client_hello.sni
|
||||
client.alpn_offers = client_hello.alpn_protocols
|
||||
self.conn.sni = client_hello.sni
|
||||
self.conn.alpn_offers = client_hello.alpn_protocols
|
||||
tls_clienthello = ClientHelloData(self.context)
|
||||
yield TlsClienthelloHook(tls_clienthello)
|
||||
|
||||
@ -365,7 +364,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
yield commands.Log("Unable to establish TLS connection with server. "
|
||||
"Trying to establish TLS with client anyway.")
|
||||
|
||||
yield from self.start_tls(client, bytes(self.recv_buffer))
|
||||
yield from self.start_tls(bytes(self.recv_buffer))
|
||||
self.recv_buffer.clear()
|
||||
self._handle_event = super()._handle_event
|
||||
|
||||
@ -379,25 +378,16 @@ class ClientTLSLayer(_TLSLayer):
|
||||
We often need information from the upstream connection to establish TLS with the client.
|
||||
For example, we need to check if the client does ALPN or not.
|
||||
"""
|
||||
server = self.context.server
|
||||
if not server.connected:
|
||||
err = yield commands.OpenConnection(server)
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
yield commands.Log(
|
||||
f"Cannot establish server connection: {err}"
|
||||
)
|
||||
yield commands.Log(f"Cannot establish server connection: {err}")
|
||||
return err
|
||||
else:
|
||||
return None
|
||||
|
||||
err = yield EstablishServerTLS(server)
|
||||
if err:
|
||||
yield commands.Log(
|
||||
f"Cannot establish TLS with server: {err}"
|
||||
)
|
||||
return err
|
||||
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
||||
if conn.sni:
|
||||
dest = conn.sni.decode("idna")
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
if self.conn.sni:
|
||||
dest = self.conn.sni.decode("idna")
|
||||
else:
|
||||
dest = human.format_address(self.context.server.address)
|
||||
if "unknown ca" in err or "bad certificate" in err:
|
||||
@ -409,4 +399,4 @@ class ClientTLSLayer(_TLSLayer):
|
||||
f"The client {keyword} trust the proxy's certificate for {dest} ({err})",
|
||||
level="warn"
|
||||
)
|
||||
yield from super().on_handshake_error(conn, err)
|
||||
yield from super().on_handshake_error(err)
|
||||
|
@ -18,7 +18,7 @@ def expect(*event_types):
|
||||
if isinstance(event, event_types):
|
||||
yield from f(self, event)
|
||||
else:
|
||||
event_types_str = '|'.join(e.__name__ for e in event_types)
|
||||
event_types_str = '|'.join(e.__name__ for e in event_types) or "no events"
|
||||
raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.")
|
||||
|
||||
return wrapper
|
||||
|
@ -106,8 +106,8 @@ def test_redirect(strategy, https_server, https_client, tctx):
|
||||
p << OpenConnection(server)
|
||||
p >> reply(None)
|
||||
if https_server:
|
||||
p << tls.EstablishServerTLS(server)
|
||||
p >> reply_establish_server_tls()
|
||||
pass # p << tls.EstablishServerTLS(server)
|
||||
# p >> reply_establish_server_tls()
|
||||
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
|
||||
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||
@ -123,6 +123,7 @@ def test_multiple_server_connections(tctx):
|
||||
"""Test multiple requests being rewritten to different targets."""
|
||||
server1 = Placeholder()
|
||||
server2 = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
|
||||
def redirect(to: str):
|
||||
def side_effect(flow: HTTPFlow):
|
||||
@ -131,7 +132,7 @@ def test_multiple_server_connections(tctx):
|
||||
return side_effect
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< http.HttpRequestHook(Placeholder())
|
||||
>> reply(side_effect=redirect("http://one.redirect/"))
|
||||
@ -140,6 +141,9 @@ def test_multiple_server_connections(tctx):
|
||||
<< SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n")
|
||||
>> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
)
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< http.HttpRequestHook(Placeholder())
|
||||
>> reply(side_effect=redirect("http://two.redirect/"))
|
||||
|
@ -4,8 +4,6 @@ import typing
|
||||
import pytest
|
||||
from OpenSSL import SSL
|
||||
|
||||
import mitmproxy.proxy2.layer
|
||||
import mitmproxy.proxy2.layers.tls
|
||||
from mitmproxy.proxy2 import commands, context, events, layer
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.utils import data
|
||||
@ -110,11 +108,11 @@ class TlsEchoLayer(tutils.EchoLayer):
|
||||
err: typing.Optional[str] = None
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
|
||||
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
|
||||
# noinspection PyTypeChecker
|
||||
err = yield tls.EstablishServerTLS(self.context.server)
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
yield commands.SendData(event.connection, f"server-tls-failed: {err}".encode())
|
||||
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
|
||||
else:
|
||||
yield from super()._handle_event(event)
|
||||
|
||||
@ -208,9 +206,6 @@ class TestServerTLS:
|
||||
data = tutils.Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"establish-server-tls")
|
||||
<< layer.NextLayerHook(tutils.Placeholder())
|
||||
>> tutils.reply_next_layer(TlsEchoLayer)
|
||||
<< tls.TlsStartHook(tutils.Placeholder())
|
||||
>> reply_tls_start()
|
||||
<< commands.SendData(tctx.server, data)
|
||||
@ -233,6 +228,13 @@ class TestServerTLS:
|
||||
assert tctx.server.tls_established
|
||||
|
||||
# Echo
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.client, b"foo")
|
||||
<< layer.NextLayerHook(tutils.Placeholder())
|
||||
>> tutils.reply_next_layer(TlsEchoLayer)
|
||||
<< commands.SendData(tctx.client, b"foo")
|
||||
)
|
||||
_test_echo(playbook, tssl, tctx.server)
|
||||
|
||||
with pytest.raises(ssl.SSLWantReadError):
|
||||
@ -274,7 +276,7 @@ class TestServerTLS:
|
||||
assert (
|
||||
playbook
|
||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch","warn")
|
||||
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
|
||||
<< commands.CloseConnection(tctx.server)
|
||||
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
|
||||
)
|
||||
|
@ -170,7 +170,11 @@ class Playbook:
|
||||
pass
|
||||
else:
|
||||
if hasattr(x, "playbook_eval"):
|
||||
try:
|
||||
x = self.expected[i] = x.playbook_eval(self)
|
||||
except Exception:
|
||||
self.actual.append(_TracebackInPlaybook(traceback.format_exc()))
|
||||
break
|
||||
for name, value in vars(x).items():
|
||||
if isinstance(value, _Placeholder):
|
||||
setattr(x, name, value())
|
||||
@ -273,8 +277,7 @@ class reply(events.Event):
|
||||
self.to = cmd
|
||||
break
|
||||
else:
|
||||
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
|
||||
raise AssertionError(f"Expected command {self.to} did not occur:\n{actual_str}")
|
||||
raise AssertionError(f"Expected command {self.to} did not occur.")
|
||||
|
||||
assert isinstance(self.to, commands.Command)
|
||||
if isinstance(self.to, commands.Hook):
|
||||
|
Loading…
Reference in New Issue
Block a user