mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] wip: tls establishment semantics
This commit is contained in:
parent
7efe27be74
commit
b2060356b6
@ -74,16 +74,6 @@ class CloseConnection(ConnectionCommand):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ProtocolError(Command):
|
|
||||||
"""
|
|
||||||
Indicate that an unrecoverable protocol error has occured.
|
|
||||||
"""
|
|
||||||
message: str
|
|
||||||
|
|
||||||
def __init__(self, message: str):
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
|
|
||||||
class Hook(Command):
|
class Hook(Command):
|
||||||
"""
|
"""
|
||||||
Callback to the master (like ".ask()")
|
Callback to the master (like ".ask()")
|
||||||
@ -143,5 +133,3 @@ class Log(Command):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Log({self.message}, {self.level})"
|
return f"Log({self.message}, {self.level})"
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,10 +5,10 @@ from mitmproxy import flow, http
|
|||||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
from mitmproxy.proxy2 import commands, events, layer
|
from mitmproxy.proxy2 import commands, events, layer
|
||||||
from mitmproxy.proxy2.context import Connection, Context, Server
|
from mitmproxy.proxy2.context import Connection, Context, Server
|
||||||
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
|
from mitmproxy.proxy2.layers import tls
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
from ._base import HttpConnection, StreamId
|
from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp
|
||||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||||
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
||||||
@ -17,10 +17,6 @@ from ._http1 import Http1Client, Http1Server
|
|||||||
from ._http2 import Http2Client
|
from ._http2 import Http2Client
|
||||||
|
|
||||||
|
|
||||||
class HttpCommand(commands.Command):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class GetHttpConnection(HttpCommand):
|
class GetHttpConnection(HttpCommand):
|
||||||
"""
|
"""
|
||||||
Open a HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead.
|
Open a HTTP Connection. This may not actually open a connection, but return an existing HTTP connection instead.
|
||||||
@ -47,6 +43,18 @@ class GetHttpConnectionReply(events.CommandReply):
|
|||||||
"""connection object, error message"""
|
"""connection object, error message"""
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterHttpConnection(HttpCommand):
|
||||||
|
"""
|
||||||
|
Register that a HTTP connection has been successfully established.
|
||||||
|
"""
|
||||||
|
connection: Connection
|
||||||
|
err: str
|
||||||
|
|
||||||
|
def __init__(self, connection: Connection, err: str):
|
||||||
|
self.connection = connection
|
||||||
|
self.err = err
|
||||||
|
|
||||||
|
|
||||||
class SendHttp(HttpCommand):
|
class SendHttp(HttpCommand):
|
||||||
connection: Connection
|
connection: Connection
|
||||||
event: HttpEvent
|
event: HttpEvent
|
||||||
@ -59,6 +67,8 @@ class SendHttp(HttpCommand):
|
|||||||
return f"Send({self.event})"
|
return f"Send({self.event})"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HttpStream(layer.Layer):
|
class HttpStream(layer.Layer):
|
||||||
request_body_buf: bytes
|
request_body_buf: bytes
|
||||||
response_body_buf: bytes
|
response_body_buf: bytes
|
||||||
@ -299,7 +309,7 @@ class HttpStream(layer.Layer):
|
|||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
|
|
||||||
class HTTPLayer(Layer):
|
class HTTPLayer(layer.Layer):
|
||||||
"""
|
"""
|
||||||
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
|
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
|
||||||
HttpEvent: We have received request headers
|
HttpEvent: We have received request headers
|
||||||
@ -311,47 +321,34 @@ class HTTPLayer(Layer):
|
|||||||
mode: HTTPMode
|
mode: HTTPMode
|
||||||
stream_by_command: typing.Dict[commands.Command, HttpStream]
|
stream_by_command: typing.Dict[commands.Command, HttpStream]
|
||||||
streams: typing.Dict[int, HttpStream]
|
streams: typing.Dict[int, HttpStream]
|
||||||
connections: typing.Dict[Connection, typing.Union[HttpConnection, HttpStream]]
|
connections: typing.Dict[Connection, typing.Union[layer.Layer, HttpStream]]
|
||||||
waiting_for_connection: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
|
waiting_for_establishment: typing.DefaultDict[Connection, typing.List[GetHttpConnection]]
|
||||||
event_queue: typing.Deque[
|
command_queue: typing.Deque[commands.Command]
|
||||||
typing.Union[HttpEvent, HttpCommand, commands.Command]
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, context: Context, mode: HTTPMode):
|
def __init__(self, context: Context, mode: HTTPMode):
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
self.waiting_for_connection = collections.defaultdict(list)
|
self.waiting_for_establishment = collections.defaultdict(list)
|
||||||
self.streams = {}
|
self.streams = {}
|
||||||
self.stream_by_command = {}
|
self.stream_by_command = {}
|
||||||
self.event_queue = collections.deque()
|
self.command_queue = collections.deque()
|
||||||
|
|
||||||
self.connections = {
|
self.connections = {
|
||||||
context.client: Http1Server(context.client)
|
context.client: Http1Server(context.fork())
|
||||||
}
|
}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
|
return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
|
||||||
|
|
||||||
def _handle_event(self, event: events.Event):
|
def _handle_event(self, event: events.Event):
|
||||||
if isinstance(event, events.Start):
|
if isinstance(event, events.Start):
|
||||||
return
|
return
|
||||||
elif isinstance(event, (EstablishServerTLSReply, events.OpenConnectionReply)) and \
|
|
||||||
event.command.connection in self.waiting_for_connection:
|
|
||||||
if event.reply:
|
|
||||||
waiting = self.waiting_for_connection.pop(event.command.connection)
|
|
||||||
for cmd in waiting:
|
|
||||||
stream = self.stream_by_command.pop(cmd)
|
|
||||||
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
|
|
||||||
else:
|
|
||||||
yield from self.make_http_connection(event.command.connection)
|
|
||||||
elif isinstance(event, events.CommandReply):
|
elif isinstance(event, events.CommandReply):
|
||||||
try:
|
try:
|
||||||
stream = self.stream_by_command.pop(event.command)
|
stream = self.stream_by_command.pop(event.command)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise
|
raise
|
||||||
if isinstance(event, events.OpenConnectionReply):
|
|
||||||
self.connections[event.command.connection] = stream
|
|
||||||
self.event_to_child(stream, event)
|
self.event_to_child(stream, event)
|
||||||
elif isinstance(event, events.ConnectionEvent):
|
elif isinstance(event, events.ConnectionEvent):
|
||||||
if event.connection == self.context.server and self.context.server not in self.connections:
|
if event.connection == self.context.server and self.context.server not in self.connections:
|
||||||
@ -362,117 +359,129 @@ class HTTPLayer(Layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
|
|
||||||
while self.event_queue:
|
while self.command_queue:
|
||||||
event = self.event_queue.popleft()
|
command = self.command_queue.popleft()
|
||||||
if isinstance(event, RequestHeaders):
|
if isinstance(command, ReceiveHttp):
|
||||||
self.streams[event.stream_id] = self.make_stream()
|
if isinstance(command.event, RequestHeaders):
|
||||||
if isinstance(event, HttpEvent):
|
self.streams[command.event.stream_id] = self.make_stream()
|
||||||
stream = self.streams[event.stream_id]
|
stream = self.streams[command.event.stream_id]
|
||||||
self.event_to_child(stream, event)
|
self.event_to_child(stream, command.event)
|
||||||
elif isinstance(event, SendHttp):
|
elif isinstance(command, SendHttp):
|
||||||
conn = self.connections[event.connection]
|
conn = self.connections[command.connection]
|
||||||
evts = conn.send(event.event)
|
self.event_to_child(conn, command.event)
|
||||||
self.event_queue.extend(evts)
|
elif isinstance(command, GetHttpConnection):
|
||||||
elif isinstance(event, GetHttpConnection):
|
self.get_connection(command)
|
||||||
yield from self.get_connection(event)
|
elif isinstance(command, RegisterHttpConnection):
|
||||||
elif isinstance(event, commands.Command):
|
yield from self.register_connection(command)
|
||||||
yield event
|
elif isinstance(command, commands.Command):
|
||||||
else:
|
yield command
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
else: # pragma: no cover
|
||||||
|
raise ValueError(f"Not a command command: {command}")
|
||||||
|
|
||||||
def make_stream(self) -> HttpStream:
|
def make_stream(self) -> HttpStream:
|
||||||
ctx = self.context.fork()
|
ctx = self.context.fork()
|
||||||
|
|
||||||
stream = HttpStream(ctx)
|
stream = HttpStream(ctx)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
stream.debug = self.debug + " "
|
stream.debug = self.debug + " "
|
||||||
self.event_to_child(stream, events.Start())
|
self.event_to_child(stream, events.Start())
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
def get_connection(self, event: GetHttpConnection):
|
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
|
||||||
# Do we already have a connection we can re-use?
|
# Do we already have a connection we can re-use?
|
||||||
for connection, handler in self.connections.items():
|
for connection, layer in self.connections.items():
|
||||||
connection_suitable = (
|
connection_suitable = (
|
||||||
|
reuse and
|
||||||
event.connection_spec_matches(connection) and
|
event.connection_spec_matches(connection) and
|
||||||
(
|
(
|
||||||
isinstance(handler, Http2Client) or
|
# see "tricky multiplexing edge case" in make_http_connection for an explanation
|
||||||
# see "tricky multiplexing edge case" in make_http_connection for an explanation
|
connection.alpn == b"h2" or self.context.client.alpn != b"h2"
|
||||||
isinstance(handler, Http1Client) and self.context.client.alpn != b"h2"
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if connection_suitable:
|
if connection_suitable:
|
||||||
stream = self.stream_by_command.pop(event)
|
if connection in self.waiting_for_establishment:
|
||||||
self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None)))
|
self.waiting_for_establishment[connection].append(event)
|
||||||
|
else:
|
||||||
|
stream = self.stream_by_command.pop(event)
|
||||||
|
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None)))
|
||||||
return
|
return
|
||||||
# Are we waiting for one?
|
|
||||||
for connection in self.waiting_for_connection:
|
|
||||||
if event.connection_spec_matches(connection):
|
|
||||||
self.waiting_for_connection[connection].append(event)
|
|
||||||
return
|
|
||||||
# Can we reuse context.server?
|
|
||||||
can_reuse_context_connection = (
|
can_reuse_context_connection = (
|
||||||
self.context.server not in self.connections and
|
self.context.server not in self.connections and
|
||||||
self.context.server.connected and
|
self.context.server.connected and
|
||||||
self.context.server.address == event.address and
|
event.connection_spec_matches(self.context.server)
|
||||||
self.context.server.tls == event.tls
|
|
||||||
)
|
)
|
||||||
if can_reuse_context_connection:
|
context = self.context.fork()
|
||||||
self.waiting_for_connection[self.context.server].append(event)
|
layer = HttpClient(context)
|
||||||
yield from self.make_http_connection(self.context.server)
|
if not can_reuse_context_connection:
|
||||||
# We need a new one.
|
context.server = Server(event.address)
|
||||||
|
if event.tls:
|
||||||
|
context.server.tls = True
|
||||||
|
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
|
||||||
|
orig = layer
|
||||||
|
layer = tls.ServerTLSLayer(context)
|
||||||
|
layer.child_layer = orig
|
||||||
|
# TODO: Here we should create other sublayer(s) for upstream proxy etc.
|
||||||
|
|
||||||
|
self.connections[context.server] = layer
|
||||||
|
self.waiting_for_establishment[context.server].append(event)
|
||||||
|
|
||||||
|
self.event_to_child(layer, events.Start())
|
||||||
|
|
||||||
|
def register_connection(self, command: RegisterHttpConnection):
|
||||||
|
waiting = self.waiting_for_establishment.pop(command.connection)
|
||||||
|
|
||||||
|
if command.err:
|
||||||
|
reply = (None, command.err)
|
||||||
|
self.connections.pop(command.connection)
|
||||||
else:
|
else:
|
||||||
connection = Server(event.address)
|
reply = (command.connection, None)
|
||||||
connection.tls = event.tls
|
|
||||||
self.waiting_for_connection[connection].append(event)
|
|
||||||
open_command = commands.OpenConnection(connection)
|
|
||||||
open_command.blocking = object()
|
|
||||||
yield open_command
|
|
||||||
|
|
||||||
def make_http_connection(self, connection: Server) -> None:
|
|
||||||
if connection.tls and not connection.tls_established:
|
|
||||||
connection.alpn_offers = list(HTTP_ALPNS)
|
|
||||||
if not self.context.options.http2:
|
|
||||||
connection.alpn_offers.remove(b"h2")
|
|
||||||
new_command = EstablishServerTLS(connection)
|
|
||||||
new_command.blocking = object()
|
|
||||||
yield new_command
|
|
||||||
return
|
|
||||||
|
|
||||||
if connection.alpn == b"h2":
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
self.connections[connection] = Http1Client(connection)
|
|
||||||
|
|
||||||
waiting = self.waiting_for_connection.pop(connection)
|
|
||||||
for cmd in waiting:
|
for cmd in waiting:
|
||||||
stream = self.stream_by_command.pop(cmd)
|
stream = self.stream_by_command.pop(cmd)
|
||||||
self.event_to_child(stream, GetHttpConnectionReply(cmd, (connection, None)))
|
self.event_to_child(stream, GetHttpConnectionReply(cmd, reply))
|
||||||
|
|
||||||
# Tricky multiplexing edge case: Assume a h2 client that sends two requests (or receives two responses)
|
# Tricky multiplexing edge case: Assume a h2 client that sends two requests (or receives two responses)
|
||||||
# that neither have a content-length specified nor a chunked transfer encoding.
|
# that neither have a content-length specified nor a chunked transfer encoding.
|
||||||
# We can't process these two flows to the same h1 connection as they would both have
|
# We can't process these two flows to the same h1 connection as they would both have
|
||||||
# "read until eof" semantics. We could force chunked transfer encoding for requests, but can't enforce that
|
# "read until eof" semantics. We could force chunked transfer encoding for requests, but can't enforce that
|
||||||
# for responses. The only workaround left is to open a separate connection for each flow.
|
# for responses. The only workaround left is to open a separate connection for each flow.
|
||||||
if self.context.client.alpn == b"h2" and connection.alpn != b"h2":
|
if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
|
||||||
for cmd in waiting[1:]:
|
for cmd in waiting[1:]:
|
||||||
new_connection = Server(connection.address)
|
yield from self.get_connection(cmd, reuse=False)
|
||||||
new_connection.tls = connection.tls
|
|
||||||
self.waiting_for_connection[new_connection].append(cmd)
|
|
||||||
open_command = commands.OpenConnection(new_connection)
|
|
||||||
open_command.blocking = object()
|
|
||||||
yield open_command
|
|
||||||
break
|
break
|
||||||
|
|
||||||
def event_to_child(
|
def event_to_child(
|
||||||
self,
|
self,
|
||||||
stream: typing.Union[HttpConnection, HttpStream],
|
child: typing.Union[layer.Layer, HttpStream],
|
||||||
event: events.Event,
|
event: events.Event,
|
||||||
) -> None:
|
) -> None:
|
||||||
stream_events = list(stream.handle_event(event))
|
child_commands = list(child.handle_event(event))
|
||||||
for se in stream_events:
|
for cmd in child_commands:
|
||||||
|
assert isinstance(cmd, commands.Command)
|
||||||
# Streams may yield blocking commands, which ultimately generate CommandReply events.
|
# Streams may yield blocking commands, which ultimately generate CommandReply events.
|
||||||
# Those need to be routed back to the correct stream, so we need to keep track of that.
|
# Those need to be routed back to the correct stream, so we need to keep track of that.
|
||||||
if isinstance(se, commands.Command) and se.blocking:
|
if isinstance(cmd, commands.OpenConnection):
|
||||||
self.stream_by_command[se] = stream
|
self.connections[cmd.connection] = child
|
||||||
|
|
||||||
self.event_queue.extend(stream_events)
|
if cmd.blocking:
|
||||||
|
self.stream_by_command[cmd] = child
|
||||||
|
|
||||||
|
self.command_queue.extend(child_commands)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpClient(layer.Layer):
|
||||||
|
@expect(events.Start)
|
||||||
|
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
|
if self.context.server.connected:
|
||||||
|
err = None
|
||||||
|
else:
|
||||||
|
err = yield commands.OpenConnection(self.context.server)
|
||||||
|
yield RegisterHttpConnection(self.context.server, err)
|
||||||
|
if err:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.context.server.alpn == b"h2":
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
child_layer = Http1Client(self.context)
|
||||||
|
self._handle_event = child_layer.handle_event
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import abc
|
|
||||||
import typing
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mitmproxy.proxy2 import events, layer
|
from mitmproxy.proxy2 import events, layer, commands
|
||||||
|
|
||||||
StreamId = int
|
StreamId = int
|
||||||
|
|
||||||
@ -18,14 +16,22 @@ class HttpEvent(events.Event):
|
|||||||
return f"{type(self).__name__}({repr(x) if x else ''})"
|
return f"{type(self).__name__}({repr(x) if x else ''})"
|
||||||
|
|
||||||
|
|
||||||
class HttpConnection(abc.ABC):
|
class HttpConnection(layer.Layer):
|
||||||
@abc.abstractmethod
|
pass
|
||||||
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
|
||||||
yield from ()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
class HttpCommand(commands.Command):
|
||||||
yield from ()
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiveHttp(HttpCommand):
|
||||||
|
event: HttpEvent
|
||||||
|
|
||||||
|
def __init__(self, event: HttpEvent):
|
||||||
|
self.event = event
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Receive({self.event})"
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
|
import abc
|
||||||
import typing
|
import typing
|
||||||
from abc import abstractmethod
|
|
||||||
|
|
||||||
import h11
|
import h11
|
||||||
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
|
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
|
||||||
@ -9,8 +9,9 @@ from mitmproxy import http
|
|||||||
from mitmproxy.net.http import http1
|
from mitmproxy.net.http import http1
|
||||||
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
|
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
|
||||||
from mitmproxy.proxy2 import commands, events, layer
|
from mitmproxy.proxy2 import commands, events, layer
|
||||||
from mitmproxy.proxy2.context import Client, Connection, Server
|
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
||||||
from mitmproxy.proxy2.layers.http._base import StreamId
|
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
|
||||||
|
from mitmproxy.proxy2.utils import expect
|
||||||
from ._base import HttpConnection
|
from ._base import HttpConnection
|
||||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||||
@ -18,28 +19,32 @@ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders
|
|||||||
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
||||||
|
|
||||||
|
|
||||||
class Http1Connection(HttpConnection):
|
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||||
conn: Connection
|
conn: Connection
|
||||||
stream_id: typing.Optional[StreamId] = None
|
stream_id: typing.Optional[StreamId] = None
|
||||||
request: typing.Optional[http.HTTPRequest] = None
|
request: typing.Optional[http.HTTPRequest] = None
|
||||||
response: typing.Optional[http.HTTPResponse] = None
|
response: typing.Optional[http.HTTPResponse] = None
|
||||||
request_done: bool = False
|
request_done: bool = False
|
||||||
response_done: bool = False
|
response_done: bool = False
|
||||||
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
|
state: typing.Callable[[events.Event], layer.CommandGenerator[None]]
|
||||||
body_reader: TBodyReader
|
body_reader: TBodyReader
|
||||||
buf: ReceiveBuffer
|
buf: ReceiveBuffer
|
||||||
|
|
||||||
def __init__(self, conn: Connection):
|
def __init__(self, context: Context, conn: Connection):
|
||||||
|
super().__init__(context)
|
||||||
assert isinstance(conn, Connection)
|
assert isinstance(conn, Connection)
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
self.buf = ReceiveBuffer()
|
self.buf = ReceiveBuffer()
|
||||||
|
|
||||||
def handle_event(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, HttpEvent):
|
||||||
self.buf += event.data
|
yield from self.send(event)
|
||||||
yield from self.state(event)
|
else:
|
||||||
|
if isinstance(event, events.DataReceived):
|
||||||
|
self.buf += event.data
|
||||||
|
yield from self.state(event)
|
||||||
|
|
||||||
@abstractmethod
|
@abc.abstractmethod
|
||||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
@ -51,7 +56,7 @@ class Http1Connection(HttpConnection):
|
|||||||
else:
|
else:
|
||||||
return ContentLengthReader(expected_size)
|
return ContentLengthReader(expected_size)
|
||||||
|
|
||||||
def read_body(self, event: events.Event, is_request: bool) -> typing.Iterator[HttpEvent]:
|
def read_body(self, event: events.Event, is_request: bool) -> layer.CommandGenerator[None]:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
@ -63,9 +68,9 @@ class Http1Connection(HttpConnection):
|
|||||||
except h11.ProtocolError as e:
|
except h11.ProtocolError as e:
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
if is_request:
|
if is_request:
|
||||||
yield RequestProtocolError(self.stream_id, str(e))
|
yield ReceiveHttp(RequestProtocolError(self.stream_id, str(e)))
|
||||||
else:
|
else:
|
||||||
yield ResponseProtocolError(self.stream_id, str(e))
|
yield ReceiveHttp(ResponseProtocolError(self.stream_id, str(e)))
|
||||||
return
|
return
|
||||||
|
|
||||||
if h11_event is None:
|
if h11_event is None:
|
||||||
@ -73,17 +78,17 @@ class Http1Connection(HttpConnection):
|
|||||||
elif isinstance(h11_event, h11.Data):
|
elif isinstance(h11_event, h11.Data):
|
||||||
h11_event.data: bytearray # type checking
|
h11_event.data: bytearray # type checking
|
||||||
if is_request:
|
if is_request:
|
||||||
yield RequestData(self.stream_id, bytes(h11_event.data))
|
yield ReceiveHttp(RequestData(self.stream_id, bytes(h11_event.data)))
|
||||||
else:
|
else:
|
||||||
yield ResponseData(self.stream_id, bytes(h11_event.data))
|
yield ReceiveHttp(ResponseData(self.stream_id, bytes(h11_event.data)))
|
||||||
elif isinstance(h11_event, h11.EndOfMessage):
|
elif isinstance(h11_event, h11.EndOfMessage):
|
||||||
if is_request:
|
if is_request:
|
||||||
yield RequestEndOfMessage(self.stream_id)
|
yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
|
||||||
else:
|
else:
|
||||||
yield ResponseEndOfMessage(self.stream_id)
|
yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
|
||||||
return
|
return
|
||||||
|
|
||||||
def wait(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
"""
|
"""
|
||||||
We wait for the current flow to be finished before parsing the next message,
|
We wait for the current flow to be finished before parsing the next message,
|
||||||
as we may want to upgrade to WebSocket or plain TCP before that.
|
as we may want to upgrade to WebSocket or plain TCP before that.
|
||||||
@ -101,8 +106,8 @@ class Http1Server(Http1Connection):
|
|||||||
"""A simple HTTP/1 server with no pipelining support."""
|
"""A simple HTTP/1 server with no pipelining support."""
|
||||||
conn: Client
|
conn: Client
|
||||||
|
|
||||||
def __init__(self, conn: Client):
|
def __init__(self, context: Context):
|
||||||
super().__init__(conn)
|
super().__init__(context, context.client)
|
||||||
self.stream_id = 1
|
self.stream_id = 1
|
||||||
self.state = self.read_request_headers
|
self.state = self.read_request_headers
|
||||||
|
|
||||||
@ -138,7 +143,7 @@ class Http1Server(Http1Connection):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{event}")
|
raise NotImplementedError(f"{event}")
|
||||||
|
|
||||||
def mark_done(self, *, request: bool = False, response: bool = False):
|
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
|
||||||
if request:
|
if request:
|
||||||
self.request_done = True
|
self.request_done = True
|
||||||
if response:
|
if response:
|
||||||
@ -152,13 +157,13 @@ class Http1Server(Http1Connection):
|
|||||||
elif self.request_done:
|
elif self.request_done:
|
||||||
self.state = self.wait
|
self.state = self.wait
|
||||||
|
|
||||||
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
def read_request_headers(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
request_head = self.buf.maybe_extract_lines()
|
request_head = self.buf.maybe_extract_lines()
|
||||||
if request_head:
|
if request_head:
|
||||||
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
||||||
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
||||||
yield RequestHeaders(self.stream_id, self.request)
|
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
|
||||||
|
|
||||||
if self.request.first_line_format == "authority":
|
if self.request.first_line_format == "authority":
|
||||||
# The previous proxy server implementation tried to read the request body here:
|
# The previous proxy server implementation tried to read the request body here:
|
||||||
@ -180,10 +185,10 @@ class Http1Server(Http1Connection):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
|
|
||||||
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
def read_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
for e in self.read_body(event, True):
|
for e in self.read_body(event, True):
|
||||||
yield e
|
yield e
|
||||||
if isinstance(e, RequestEndOfMessage):
|
if isinstance(e, ReceiveHttp) and isinstance(e.event, RequestEndOfMessage):
|
||||||
yield from self.mark_done(request=True)
|
yield from self.mark_done(request=True)
|
||||||
|
|
||||||
|
|
||||||
@ -192,8 +197,8 @@ class Http1Client(Http1Connection):
|
|||||||
send_queue: typing.List[HttpEvent]
|
send_queue: typing.List[HttpEvent]
|
||||||
"""A queue of send events for flows other than the one that is currently being transmitted."""
|
"""A queue of send events for flows other than the one that is currently being transmitted."""
|
||||||
|
|
||||||
def __init__(self, conn: Server):
|
def __init__(self, context: Context):
|
||||||
super().__init__(conn)
|
super().__init__(context, context.server)
|
||||||
self.state = self.read_response_headers
|
self.state = self.read_response_headers
|
||||||
self.send_queue = []
|
self.send_queue = []
|
||||||
|
|
||||||
@ -231,7 +236,7 @@ class Http1Client(Http1Connection):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{event}")
|
raise NotImplementedError(f"{event}")
|
||||||
|
|
||||||
def mark_done(self, *, request: bool = False, response: bool = False):
|
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
|
||||||
if request:
|
if request:
|
||||||
self.request_done = True
|
self.request_done = True
|
||||||
if response:
|
if response:
|
||||||
@ -246,15 +251,15 @@ class Http1Client(Http1Connection):
|
|||||||
for ev in send_queue:
|
for ev in send_queue:
|
||||||
yield from self.send(ev)
|
yield from self.send(ev)
|
||||||
|
|
||||||
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
@expect(events.ConnectionEvent)
|
||||||
assert isinstance(event, events.ConnectionEvent)
|
def read_response_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
response_head = self.buf.maybe_extract_lines()
|
response_head = self.buf.maybe_extract_lines()
|
||||||
|
|
||||||
if response_head:
|
if response_head:
|
||||||
response_head = [bytes(x) for x in response_head]
|
response_head = [bytes(x) for x in response_head]
|
||||||
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
|
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
|
||||||
yield ResponseHeaders(self.stream_id, self.response)
|
yield ReceiveHttp(ResponseHeaders(self.stream_id, self.response))
|
||||||
|
|
||||||
expected_size = http1.expected_http_body_size(self.request, self.response)
|
expected_size = http1.expected_http_body_size(self.request, self.response)
|
||||||
self.body_reader = self.make_body_reader(expected_size)
|
self.body_reader = self.make_body_reader(expected_size)
|
||||||
@ -266,23 +271,24 @@ class Http1Client(Http1Connection):
|
|||||||
elif isinstance(event, events.ConnectionClosed):
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
if self.stream_id:
|
if self.stream_id:
|
||||||
if self.buf:
|
if self.buf:
|
||||||
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}")
|
yield ReceiveHttp(
|
||||||
|
ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}"))
|
||||||
else:
|
else:
|
||||||
# The server has closed the connection to prevent us from continuing.
|
# The server has closed the connection to prevent us from continuing.
|
||||||
# We need to signal that to the stream.
|
# We need to signal that to the stream.
|
||||||
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
||||||
yield ResponseProtocolError(self.stream_id, "server closed connection")
|
yield ReceiveHttp(ResponseProtocolError(self.stream_id, "server closed connection"))
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
|
|
||||||
def read_response_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
@expect(events.ConnectionEvent)
|
||||||
assert isinstance(event, events.ConnectionEvent)
|
def read_response_body(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||||
for e in self.read_body(event, False):
|
for e in self.read_body(event, False):
|
||||||
yield e
|
yield e
|
||||||
if isinstance(e, ResponseEndOfMessage):
|
if isinstance(e, ReceiveHttp) and isinstance(e.event, ResponseEndOfMessage):
|
||||||
self.state = self.read_response_headers
|
self.state = self.read_response_headers
|
||||||
yield from self.mark_done(response=True)
|
yield from self.mark_done(response=True)
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ class TCPLayer(layer.Layer):
|
|||||||
if self.flow:
|
if self.flow:
|
||||||
tcp_message = tcp.TCPMessage(from_client, event.data)
|
tcp_message = tcp.TCPMessage(from_client, event.data)
|
||||||
self.flow.messages.append(tcp_message)
|
self.flow.messages.append(tcp_message)
|
||||||
yield TcpMessageHook(self.flow)t
|
yield TcpMessageHook(self.flow)
|
||||||
yield commands.SendData(send_to, tcp_message.content)
|
yield commands.SendData(send_to, tcp_message.content)
|
||||||
else:
|
else:
|
||||||
yield commands.SendData(send_to, event.data)
|
yield commands.SendData(send_to, event.data)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
from typing import Iterator, Optional, Tuple
|
||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
@ -93,20 +93,6 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
|||||||
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
|
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
|
||||||
|
|
||||||
|
|
||||||
class EstablishServerTLS(commands.ConnectionCommand):
|
|
||||||
"""Establish TLS on the given connection.
|
|
||||||
|
|
||||||
If TLS establishment fails, the connection will automatically be closed by the TLS layer."""
|
|
||||||
connection: context.Server
|
|
||||||
blocking = True
|
|
||||||
|
|
||||||
|
|
||||||
class EstablishServerTLSReply(events.CommandReply):
|
|
||||||
command: EstablishServerTLS
|
|
||||||
reply: Optional[str]
|
|
||||||
"""error message"""
|
|
||||||
|
|
||||||
|
|
||||||
# We need these classes as hooks can only have one argument at the moment.
|
# We need these classes as hooks can only have one argument at the moment.
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -131,62 +117,61 @@ class TlsClienthelloHook(Hook):
|
|||||||
|
|
||||||
|
|
||||||
class _TLSLayer(layer.Layer):
|
class _TLSLayer(layer.Layer):
|
||||||
tls: Dict[context.Connection, SSL.Connection]
|
conn: context.Connection
|
||||||
|
"""The connection for which we do TLS"""
|
||||||
|
tls: SSL.Connection = None
|
||||||
|
"""The OpenSSL connection object"""
|
||||||
child_layer: layer.Layer
|
child_layer: layer.Layer
|
||||||
ssl_context: Optional[SSL.Context] = None
|
|
||||||
|
|
||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context, conn: context.Connection):
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.tls = {}
|
self.conn = conn
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if not self.tls:
|
if not self.tls:
|
||||||
state = "inactive"
|
state = "inactive"
|
||||||
|
elif self.conn.tls_established:
|
||||||
|
state = f"passthrough {self.conn.sni} {self.conn.alpn}"
|
||||||
else:
|
else:
|
||||||
conn_states = []
|
state = f"negotiating {self.conn.sni} {self.conn.alpn}"
|
||||||
for conn in self.tls:
|
|
||||||
if conn.tls_established:
|
|
||||||
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
|
|
||||||
else:
|
|
||||||
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
|
|
||||||
state = ", ".join(conn_states)
|
|
||||||
return f"{type(self).__name__}({state})"
|
return f"{type(self).__name__}({state})"
|
||||||
|
|
||||||
def start_tls(self, conn: context.Connection, initial_data: bytes = b""):
|
def start_tls(self, initial_data: bytes = b""):
|
||||||
assert conn not in self.tls
|
assert not self.tls
|
||||||
assert conn.connected
|
assert self.conn.connected
|
||||||
conn.tls = True
|
self.conn.tls = True
|
||||||
|
|
||||||
tls_start = TlsStartData(conn, self.context)
|
tls_start = TlsStartData(self.conn, self.context)
|
||||||
yield TlsStartHook(tls_start)
|
yield TlsStartHook(tls_start)
|
||||||
self.tls[conn] = tls_start.ssl_conn
|
assert tls_start.ssl_conn
|
||||||
|
self.tls = tls_start.ssl_conn
|
||||||
|
|
||||||
yield from self.negotiate(conn, initial_data)
|
yield from self.negotiate(initial_data)
|
||||||
|
|
||||||
def tls_interact(self, conn: context.Connection) -> layer.CommandGenerator[None]:
|
def tls_interact(self) -> layer.CommandGenerator[None]:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data = self.tls[conn].bio_read(65535)
|
data = self.tls.bio_read(65535)
|
||||||
except SSL.WantReadError:
|
except SSL.WantReadError:
|
||||||
# Okay, nothing more waiting to be sent.
|
# Okay, nothing more waiting to be sent.
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
yield commands.SendData(conn, data)
|
yield commands.SendData(self.conn, data)
|
||||||
|
|
||||||
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||||
if data:
|
if data:
|
||||||
self.tls[conn].bio_write(data)
|
self.tls.bio_write(data)
|
||||||
try:
|
try:
|
||||||
self.tls[conn].do_handshake()
|
self.tls.do_handshake()
|
||||||
except SSL.WantReadError:
|
except SSL.WantReadError:
|
||||||
yield from self.tls_interact(conn)
|
yield from self.tls_interact()
|
||||||
return False, None
|
return False, None
|
||||||
except SSL.Error as e:
|
except SSL.Error as e:
|
||||||
# provide more detailed information for some errors.
|
# provide more detailed information for some errors.
|
||||||
last_err = e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1]
|
last_err = e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1]
|
||||||
if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'):
|
if last_err == ('SSL routines', 'tls_process_server_certificate', 'certificate verify failed'):
|
||||||
verify_result = SSL._lib.SSL_get_verify_result(self.tls[conn]._ssl)
|
verify_result = SSL._lib.SSL_get_verify_result(self.tls._ssl)
|
||||||
error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode()
|
error = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(verify_result)).decode()
|
||||||
err = f"Certificate verify failed: {error}"
|
err = f"Certificate verify failed: {error}"
|
||||||
elif last_err in [
|
elif last_err in [
|
||||||
@ -196,40 +181,40 @@ class _TLSLayer(layer.Layer):
|
|||||||
err = last_err[2]
|
err = last_err[2]
|
||||||
else:
|
else:
|
||||||
err = repr(e)
|
err = repr(e)
|
||||||
yield from self.on_handshake_error(conn, err)
|
yield from self.on_handshake_error(err)
|
||||||
return False, err
|
return False, err
|
||||||
else:
|
else:
|
||||||
# Get all peer certificates.
|
# Get all peer certificates.
|
||||||
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
||||||
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
||||||
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
|
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
|
||||||
all_certs = self.tls[conn].get_peer_cert_chain() or []
|
all_certs = self.tls.get_peer_cert_chain() or []
|
||||||
if conn == self.context.client:
|
if self.conn == self.context.client:
|
||||||
cert = self.tls[conn].get_peer_certificate()
|
cert = self.tls.get_peer_certificate()
|
||||||
if cert:
|
if cert:
|
||||||
all_certs.insert(0, cert)
|
all_certs.insert(0, cert)
|
||||||
|
|
||||||
conn.tls_established = True
|
self.conn.tls_established = True
|
||||||
conn.sni = self.tls[conn].get_servername()
|
self.conn.sni = self.tls.get_servername()
|
||||||
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
|
||||||
conn.certificate_chain = [certs.Cert(x) for x in all_certs]
|
self.conn.certificate_chain = [certs.Cert(x) for x in all_certs]
|
||||||
conn.cipher_list = self.tls[conn].get_cipher_list()
|
self.conn.cipher_list = self.tls.get_cipher_list()
|
||||||
conn.tls_version = self.tls[conn].get_protocol_version_name()
|
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||||
conn.timestamp_tls_setup = time.time()
|
self.conn.timestamp_tls_setup = time.time()
|
||||||
yield commands.Log(f"TLS established: {conn}")
|
yield commands.Log(f"TLS established: {self.conn}")
|
||||||
yield from self.receive(conn, b"")
|
yield from self.receive(b"")
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def receive(self, conn: context.Connection, data: bytes):
|
def receive(self, data: bytes):
|
||||||
if data:
|
if data:
|
||||||
self.tls[conn].bio_write(data)
|
self.tls.bio_write(data)
|
||||||
yield from self.tls_interact(conn)
|
yield from self.tls_interact()
|
||||||
|
|
||||||
plaintext = bytearray()
|
plaintext = bytearray()
|
||||||
close = False
|
close = False
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
plaintext.extend(self.tls[conn].recv(65535))
|
plaintext.extend(self.tls.recv(65535))
|
||||||
except SSL.WantReadError:
|
except SSL.WantReadError:
|
||||||
break
|
break
|
||||||
except SSL.ZeroReturnError:
|
except SSL.ZeroReturnError:
|
||||||
@ -238,76 +223,90 @@ class _TLSLayer(layer.Layer):
|
|||||||
|
|
||||||
if plaintext:
|
if plaintext:
|
||||||
yield from self.event_to_child(
|
yield from self.event_to_child(
|
||||||
events.DataReceived(conn, bytes(plaintext))
|
events.DataReceived(self.conn, bytes(plaintext))
|
||||||
)
|
)
|
||||||
if close:
|
if close:
|
||||||
conn.state &= ~context.ConnectionState.CAN_READ
|
self.conn.state &= ~context.ConnectionState.CAN_READ
|
||||||
yield commands.Log(f"TLS close_notify {conn=}")
|
yield commands.Log(f"TLS close_notify {self.conn}")
|
||||||
yield from self.event_to_child(
|
yield from self.event_to_child(
|
||||||
events.ConnectionClosed(conn)
|
events.ConnectionClosed(self.conn)
|
||||||
)
|
)
|
||||||
|
|
||||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
for command in self.child_layer.handle_event(event):
|
for command in self.child_layer.handle_event(event):
|
||||||
if isinstance(command, commands.SendData) and command.connection in self.tls:
|
if isinstance(command, commands.SendData) and command.connection == self.conn:
|
||||||
self.tls[command.connection].sendall(command.data)
|
self.tls.sendall(command.data)
|
||||||
yield from self.tls_interact(command.connection)
|
yield from self.tls_interact()
|
||||||
else:
|
else:
|
||||||
yield command
|
yield command
|
||||||
|
|
||||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.DataReceived) and event.connection in self.tls:
|
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
||||||
if not event.connection.tls_established:
|
if not event.connection.tls_established:
|
||||||
yield from self.negotiate(event.connection, event.data)
|
yield from self.negotiate(event.data)
|
||||||
else:
|
else:
|
||||||
yield from self.receive(event.connection, event.data)
|
yield from self.receive(event.data)
|
||||||
elif isinstance(event, events.ConnectionClosed) and event.connection in self.tls:
|
elif isinstance(event, events.ConnectionClosed) and event.connection == self.conn:
|
||||||
if event.connection.tls_established:
|
if event.connection.tls_established:
|
||||||
if self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN:
|
if self.tls.get_shutdown() & SSL.RECEIVED_SHUTDOWN:
|
||||||
pass # We have already dispatched a ConnectionClosed to the child layer.
|
pass # We have already dispatched a ConnectionClosed to the child layer.
|
||||||
else:
|
else:
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
else:
|
else:
|
||||||
yield from self.on_handshake_error(event.connection, "connection closed without notice")
|
yield from self.on_handshake_error("connection closed without notice")
|
||||||
else:
|
else:
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||||
yield commands.CloseConnection(conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
|
|
||||||
|
|
||||||
class ServerTLSLayer(_TLSLayer):
|
class ServerTLSLayer(_TLSLayer):
|
||||||
"""
|
"""
|
||||||
This layer manages TLS for potentially multiple server connections.
|
This layer establishes TLS for a single server connection.
|
||||||
"""
|
"""
|
||||||
command_to_reply_to: Dict[context.Connection, commands.OpenConnection]
|
command_to_reply_to: Optional[commands.OpenConnection] = None
|
||||||
|
|
||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context):
|
||||||
super().__init__(context)
|
super().__init__(context, context.server)
|
||||||
self.command_to_reply_to = {}
|
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
self.child_layer = layer.NextLayer(self.context)
|
||||||
|
|
||||||
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
@expect(events.Start)
|
||||||
done, err = yield from super().negotiate(conn, data)
|
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||||
|
self.context.server.tls = True
|
||||||
|
if self.context.server.connected:
|
||||||
|
yield from self.start_tls()
|
||||||
|
self._handle_event = super()._handle_event
|
||||||
|
|
||||||
|
_handle_event = state_start
|
||||||
|
|
||||||
|
def negotiate(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||||
|
done, err = yield from super().negotiate(data)
|
||||||
if done or err:
|
if done or err:
|
||||||
cmd = self.command_to_reply_to.pop(conn)
|
cmd = self.command_to_reply_to
|
||||||
yield from self.event_to_child(events.OpenConnectionReply(cmd, err))
|
yield from self.event_to_child(events.OpenConnectionReply(cmd, err))
|
||||||
|
self.command_to_reply_to = None
|
||||||
return done, err
|
return done, err
|
||||||
|
|
||||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
for command in super().event_to_child(event):
|
for command in super().event_to_child(event):
|
||||||
if isinstance(command, EstablishServerTLS):
|
if isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
|
||||||
self.command_to_reply_to[command.connection] = command
|
# create our own OpenConnection command object that blocks here.
|
||||||
yield from self.start_tls(command.connection)
|
err = yield commands.OpenConnection(command.connection)
|
||||||
|
if err:
|
||||||
|
yield from self.event_to_child(events.OpenConnectionReply(command, err))
|
||||||
|
else:
|
||||||
|
self.command_to_reply_to = command
|
||||||
|
yield from self.start_tls()
|
||||||
else:
|
else:
|
||||||
yield command
|
yield command
|
||||||
|
|
||||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
f"Server TLS handshake failed. {err}",
|
f"Server TLS handshake failed. {err}",
|
||||||
level="warn"
|
level="warn"
|
||||||
)
|
)
|
||||||
yield from super().on_handshake_error(conn, err)
|
yield from super().on_handshake_error(err)
|
||||||
|
|
||||||
|
|
||||||
class ClientTLSLayer(_TLSLayer):
|
class ClientTLSLayer(_TLSLayer):
|
||||||
@ -331,10 +330,9 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
|
|
||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context):
|
||||||
assert isinstance(context.layers[-1], ServerTLSLayer)
|
assert isinstance(context.layers[-1], ServerTLSLayer)
|
||||||
super().__init__(context)
|
super().__init__(context, self.context.client)
|
||||||
self.recv_buffer = bytearray()
|
self.recv_buffer = bytearray()
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
self.child_layer = layer.NextLayer(self.context)
|
||||||
self._handle_event = self.state_start
|
|
||||||
|
|
||||||
@expect(events.Start)
|
@expect(events.Start)
|
||||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||||
@ -342,20 +340,21 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
self._handle_event = self.state_wait_for_clienthello
|
self._handle_event = self.state_wait_for_clienthello
|
||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
|
_handle_event = state_start
|
||||||
|
|
||||||
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
client = self.context.client
|
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
||||||
if isinstance(event, events.DataReceived) and event.connection == client:
|
|
||||||
self.recv_buffer.extend(event.data)
|
self.recv_buffer.extend(event.data)
|
||||||
try:
|
try:
|
||||||
client_hello = parse_client_hello(self.recv_buffer)
|
client_hello = parse_client_hello(self.recv_buffer)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
yield commands.Log(f"Cannot parse ClientHello: {self.recv_buffer.hex()}")
|
yield commands.Log(f"Cannot parse ClientHello: {self.recv_buffer.hex()}")
|
||||||
yield commands.CloseConnection(client)
|
yield commands.CloseConnection(self.conn)
|
||||||
return
|
return
|
||||||
|
|
||||||
if client_hello:
|
if client_hello:
|
||||||
client.sni = client_hello.sni
|
self.conn.sni = client_hello.sni
|
||||||
client.alpn_offers = client_hello.alpn_protocols
|
self.conn.alpn_offers = client_hello.alpn_protocols
|
||||||
tls_clienthello = ClientHelloData(self.context)
|
tls_clienthello = ClientHelloData(self.context)
|
||||||
yield TlsClienthelloHook(tls_clienthello)
|
yield TlsClienthelloHook(tls_clienthello)
|
||||||
|
|
||||||
@ -365,7 +364,7 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
yield commands.Log("Unable to establish TLS connection with server. "
|
yield commands.Log("Unable to establish TLS connection with server. "
|
||||||
"Trying to establish TLS with client anyway.")
|
"Trying to establish TLS with client anyway.")
|
||||||
|
|
||||||
yield from self.start_tls(client, bytes(self.recv_buffer))
|
yield from self.start_tls(bytes(self.recv_buffer))
|
||||||
self.recv_buffer.clear()
|
self.recv_buffer.clear()
|
||||||
self._handle_event = super()._handle_event
|
self._handle_event = super()._handle_event
|
||||||
|
|
||||||
@ -379,25 +378,16 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
We often need information from the upstream connection to establish TLS with the client.
|
We often need information from the upstream connection to establish TLS with the client.
|
||||||
For example, we need to check if the client does ALPN or not.
|
For example, we need to check if the client does ALPN or not.
|
||||||
"""
|
"""
|
||||||
server = self.context.server
|
err = yield commands.OpenConnection(self.context.server)
|
||||||
if not server.connected:
|
|
||||||
err = yield commands.OpenConnection(server)
|
|
||||||
if err:
|
|
||||||
yield commands.Log(
|
|
||||||
f"Cannot establish server connection: {err}"
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
|
|
||||||
err = yield EstablishServerTLS(server)
|
|
||||||
if err:
|
if err:
|
||||||
yield commands.Log(
|
yield commands.Log(f"Cannot establish server connection: {err}")
|
||||||
f"Cannot establish TLS with server: {err}"
|
|
||||||
)
|
|
||||||
return err
|
return err
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||||
if conn.sni:
|
if self.conn.sni:
|
||||||
dest = conn.sni.decode("idna")
|
dest = self.conn.sni.decode("idna")
|
||||||
else:
|
else:
|
||||||
dest = human.format_address(self.context.server.address)
|
dest = human.format_address(self.context.server.address)
|
||||||
if "unknown ca" in err or "bad certificate" in err:
|
if "unknown ca" in err or "bad certificate" in err:
|
||||||
@ -409,4 +399,4 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
f"The client {keyword} trust the proxy's certificate for {dest} ({err})",
|
f"The client {keyword} trust the proxy's certificate for {dest} ({err})",
|
||||||
level="warn"
|
level="warn"
|
||||||
)
|
)
|
||||||
yield from super().on_handshake_error(conn, err)
|
yield from super().on_handshake_error(err)
|
||||||
|
@ -18,7 +18,7 @@ def expect(*event_types):
|
|||||||
if isinstance(event, event_types):
|
if isinstance(event, event_types):
|
||||||
yield from f(self, event)
|
yield from f(self, event)
|
||||||
else:
|
else:
|
||||||
event_types_str = '|'.join(e.__name__ for e in event_types)
|
event_types_str = '|'.join(e.__name__ for e in event_types) or "no events"
|
||||||
raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.")
|
raise AssertionError(f"Unexpected event type at {f.__qualname__}: Expected {event_types_str}, got {event}.")
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
@ -106,8 +106,8 @@ def test_redirect(strategy, https_server, https_client, tctx):
|
|||||||
p << OpenConnection(server)
|
p << OpenConnection(server)
|
||||||
p >> reply(None)
|
p >> reply(None)
|
||||||
if https_server:
|
if https_server:
|
||||||
p << tls.EstablishServerTLS(server)
|
pass # p << tls.EstablishServerTLS(server)
|
||||||
p >> reply_establish_server_tls()
|
# p >> reply_establish_server_tls()
|
||||||
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
|
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
|
||||||
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||||
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||||
@ -123,6 +123,7 @@ def test_multiple_server_connections(tctx):
|
|||||||
"""Test multiple requests being rewritten to different targets."""
|
"""Test multiple requests being rewritten to different targets."""
|
||||||
server1 = Placeholder()
|
server1 = Placeholder()
|
||||||
server2 = Placeholder()
|
server2 = Placeholder()
|
||||||
|
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
|
||||||
def redirect(to: str):
|
def redirect(to: str):
|
||||||
def side_effect(flow: HTTPFlow):
|
def side_effect(flow: HTTPFlow):
|
||||||
@ -131,7 +132,7 @@ def test_multiple_server_connections(tctx):
|
|||||||
return side_effect
|
return side_effect
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
playbook
|
||||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< http.HttpRequestHook(Placeholder())
|
<< http.HttpRequestHook(Placeholder())
|
||||||
>> reply(side_effect=redirect("http://one.redirect/"))
|
>> reply(side_effect=redirect("http://one.redirect/"))
|
||||||
@ -140,6 +141,9 @@ def test_multiple_server_connections(tctx):
|
|||||||
<< SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n")
|
<< SendData(server1, b"GET / HTTP/1.1\r\nHost: one.redirect\r\n\r\n")
|
||||||
>> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
>> DataReceived(server1, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< http.HttpRequestHook(Placeholder())
|
<< http.HttpRequestHook(Placeholder())
|
||||||
>> reply(side_effect=redirect("http://two.redirect/"))
|
>> reply(side_effect=redirect("http://two.redirect/"))
|
||||||
|
@ -4,8 +4,6 @@ import typing
|
|||||||
import pytest
|
import pytest
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
import mitmproxy.proxy2.layer
|
|
||||||
import mitmproxy.proxy2.layers.tls
|
|
||||||
from mitmproxy.proxy2 import commands, context, events, layer
|
from mitmproxy.proxy2 import commands, context, events, layer
|
||||||
from mitmproxy.proxy2.layers import tls
|
from mitmproxy.proxy2.layers import tls
|
||||||
from mitmproxy.utils import data
|
from mitmproxy.utils import data
|
||||||
@ -110,11 +108,11 @@ class TlsEchoLayer(tutils.EchoLayer):
|
|||||||
err: typing.Optional[str] = None
|
err: typing.Optional[str] = None
|
||||||
|
|
||||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
|
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
err = yield tls.EstablishServerTLS(self.context.server)
|
err = yield commands.OpenConnection(self.context.server)
|
||||||
if err:
|
if err:
|
||||||
yield commands.SendData(event.connection, f"server-tls-failed: {err}".encode())
|
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
|
||||||
else:
|
else:
|
||||||
yield from super()._handle_event(event)
|
yield from super()._handle_event(event)
|
||||||
|
|
||||||
@ -208,9 +206,6 @@ class TestServerTLS:
|
|||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.client, b"establish-server-tls")
|
|
||||||
<< layer.NextLayerHook(tutils.Placeholder())
|
|
||||||
>> tutils.reply_next_layer(TlsEchoLayer)
|
|
||||||
<< tls.TlsStartHook(tutils.Placeholder())
|
<< tls.TlsStartHook(tutils.Placeholder())
|
||||||
>> reply_tls_start()
|
>> reply_tls_start()
|
||||||
<< commands.SendData(tctx.server, data)
|
<< commands.SendData(tctx.server, data)
|
||||||
@ -233,6 +228,13 @@ class TestServerTLS:
|
|||||||
assert tctx.server.tls_established
|
assert tctx.server.tls_established
|
||||||
|
|
||||||
# Echo
|
# Echo
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> events.DataReceived(tctx.client, b"foo")
|
||||||
|
<< layer.NextLayerHook(tutils.Placeholder())
|
||||||
|
>> tutils.reply_next_layer(TlsEchoLayer)
|
||||||
|
<< commands.SendData(tctx.client, b"foo")
|
||||||
|
)
|
||||||
_test_echo(playbook, tssl, tctx.server)
|
_test_echo(playbook, tssl, tctx.server)
|
||||||
|
|
||||||
with pytest.raises(ssl.SSLWantReadError):
|
with pytest.raises(ssl.SSLWantReadError):
|
||||||
@ -274,7 +276,7 @@ class TestServerTLS:
|
|||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||||
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch","warn")
|
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
|
||||||
<< commands.CloseConnection(tctx.server)
|
<< commands.CloseConnection(tctx.server)
|
||||||
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
|
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
|
||||||
)
|
)
|
||||||
|
@ -170,7 +170,11 @@ class Playbook:
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if hasattr(x, "playbook_eval"):
|
if hasattr(x, "playbook_eval"):
|
||||||
x = self.expected[i] = x.playbook_eval(self)
|
try:
|
||||||
|
x = self.expected[i] = x.playbook_eval(self)
|
||||||
|
except Exception:
|
||||||
|
self.actual.append(_TracebackInPlaybook(traceback.format_exc()))
|
||||||
|
break
|
||||||
for name, value in vars(x).items():
|
for name, value in vars(x).items():
|
||||||
if isinstance(value, _Placeholder):
|
if isinstance(value, _Placeholder):
|
||||||
setattr(x, name, value())
|
setattr(x, name, value())
|
||||||
@ -273,8 +277,7 @@ class reply(events.Event):
|
|||||||
self.to = cmd
|
self.to = cmd
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
|
raise AssertionError(f"Expected command {self.to} did not occur.")
|
||||||
raise AssertionError(f"Expected command {self.to} did not occur:\n{actual_str}")
|
|
||||||
|
|
||||||
assert isinstance(self.to, commands.Command)
|
assert isinstance(self.to, commands.Command)
|
||||||
if isinstance(self.to, commands.Hook):
|
if isinstance(self.to, commands.Hook):
|
||||||
|
Loading…
Reference in New Issue
Block a user