mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] refactor
This commit is contained in:
parent
0740c673bd
commit
7efe27be74
@ -40,7 +40,7 @@ class ConnectionCommand(Command):
|
||||
"""
|
||||
connection: Connection
|
||||
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
def __init__(self, connection: Connection):
|
||||
self.connection = connection
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ class SendData(ConnectionCommand):
|
||||
"""
|
||||
data: bytes
|
||||
|
||||
def __init__(self, connection: Connection, data: bytes) -> None:
|
||||
def __init__(self, connection: Connection, data: bytes):
|
||||
super().__init__(connection)
|
||||
self.data = data
|
||||
|
||||
@ -74,6 +74,16 @@ class CloseConnection(ConnectionCommand):
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolError(Command):
|
||||
"""
|
||||
Indicate that an unrecoverable protocol error has occured.
|
||||
"""
|
||||
message: str
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
|
||||
class Hook(Command):
|
||||
"""
|
||||
Callback to the master (like ".ask()")
|
||||
@ -135,4 +145,3 @@ class Log(Command):
|
||||
return f"Log({self.message}, {self.level})"
|
||||
|
||||
|
||||
TCommandGenerator = typing.Generator[Command, typing.Any, None]
|
||||
|
@ -8,16 +8,22 @@ from abc import abstractmethod
|
||||
|
||||
from mitmproxy import log
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.commands import Hook
|
||||
from mitmproxy.proxy2.commands import Command, Hook
|
||||
from mitmproxy.proxy2.context import Connection, Context
|
||||
|
||||
T = typing.TypeVar('T')
|
||||
CommandGenerator = typing.Generator[Command, typing.Optional[events.CommandReply], T]
|
||||
"""
|
||||
A function annotated with CommandGenerator[bool] may yield commands and ultimately return a boolean value.
|
||||
"""
|
||||
|
||||
|
||||
class Paused(typing.NamedTuple):
|
||||
"""
|
||||
State of a layer that's paused because it is waiting for a command reply.
|
||||
"""
|
||||
command: commands.Command
|
||||
generator: commands.TCommandGenerator
|
||||
generator: CommandGenerator
|
||||
|
||||
|
||||
class Layer:
|
||||
@ -65,11 +71,11 @@ class Layer:
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
|
||||
"""Handle a proxy server event"""
|
||||
yield from () # pragma: no cover
|
||||
|
||||
def handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def handle_event(self, event: events.Event) -> CommandGenerator[None]:
|
||||
if self._paused:
|
||||
# did we just receive the reply we were waiting for?
|
||||
pause_finished = (
|
||||
@ -88,7 +94,7 @@ class Layer:
|
||||
command_generator = self._handle_event(event)
|
||||
yield from self.__process(command_generator)
|
||||
|
||||
def __process(self, command_generator: commands.TCommandGenerator, send=None):
|
||||
def __process(self, command_generator: CommandGenerator, send=None):
|
||||
"""
|
||||
yield all commands from a generator.
|
||||
if a command is blocking, the layer is paused and this function returns before
|
||||
|
@ -3,9 +3,8 @@ import typing
|
||||
|
||||
from mitmproxy import flow, http
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Connection, Context, Server
|
||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
||||
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
@ -60,12 +59,12 @@ class SendHttp(HttpCommand):
|
||||
return f"Send({self.event})"
|
||||
|
||||
|
||||
class HttpStream(Layer):
|
||||
class HttpStream(layer.Layer):
|
||||
request_body_buf: bytes
|
||||
response_body_buf: bytes
|
||||
flow: http.HTTPFlow
|
||||
stream_id: StreamId
|
||||
child_layer: typing.Optional[Layer] = None
|
||||
child_layer: typing.Optional[layer.Layer] = None
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
@ -80,7 +79,7 @@ class HttpStream(Layer):
|
||||
self.server_state = self.state_uninitialized
|
||||
|
||||
@expect(events.Start, HttpEvent)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
self.client_state = self.state_wait_for_request_headers
|
||||
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
|
||||
@ -91,7 +90,7 @@ class HttpStream(Layer):
|
||||
yield from self.server_state(event)
|
||||
|
||||
@expect(RequestHeaders)
|
||||
def state_wait_for_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
|
||||
def state_wait_for_request_headers(self, event: RequestHeaders) -> layer.CommandGenerator[None]:
|
||||
self.stream_id = event.stream_id
|
||||
self.flow = http.HTTPFlow(
|
||||
self.context.client,
|
||||
@ -146,7 +145,7 @@ class HttpStream(Layer):
|
||||
self.server_state = self.state_wait_for_response_headers
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage)
|
||||
def state_stream_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def state_stream_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, RequestData):
|
||||
if callable(self.flow.request.stream):
|
||||
data = self.flow.request.stream(event.data)
|
||||
@ -158,7 +157,7 @@ class HttpStream(Layer):
|
||||
self.client_state = self.state_done
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage)
|
||||
def state_consume_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def state_consume_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, RequestData):
|
||||
self.request_body_buf += event.data
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
@ -182,7 +181,7 @@ class HttpStream(Layer):
|
||||
self.client_state = self.state_done
|
||||
|
||||
@expect(ResponseHeaders)
|
||||
def state_wait_for_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
|
||||
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
|
||||
self.flow.response = event.response
|
||||
yield HttpResponseHeadersHook(self.flow)
|
||||
if self.flow.response.stream:
|
||||
@ -192,7 +191,7 @@ class HttpStream(Layer):
|
||||
self.server_state = self.state_consume_response_body
|
||||
|
||||
@expect(ResponseData, ResponseEndOfMessage)
|
||||
def state_stream_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, ResponseData):
|
||||
if callable(self.flow.response.stream):
|
||||
data = self.flow.response.stream(event.data)
|
||||
@ -204,7 +203,7 @@ class HttpStream(Layer):
|
||||
self.server_state = self.state_done
|
||||
|
||||
@expect(ResponseData, ResponseEndOfMessage)
|
||||
def state_consume_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, ResponseData):
|
||||
self.response_body_buf += event.data
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
@ -222,7 +221,7 @@ class HttpStream(Layer):
|
||||
def handle_protocol_error(
|
||||
self,
|
||||
event: typing.Union[RequestProtocolError, ResponseProtocolError]
|
||||
) -> commands.TCommandGenerator:
|
||||
) -> layer.CommandGenerator[None]:
|
||||
self.flow.error = flow.Error(event.message)
|
||||
yield HttpErrorHook(self.flow)
|
||||
|
||||
@ -232,7 +231,7 @@ class HttpStream(Layer):
|
||||
yield SendHttp(event, self.context.client)
|
||||
return
|
||||
|
||||
def make_server_connection(self) -> typing.Generator[commands.Command, typing.Any, bool]:
|
||||
def make_server_connection(self) -> layer.CommandGenerator[bool]:
|
||||
connection, err = yield GetHttpConnection(
|
||||
(self.flow.request.host, self.flow.request.port),
|
||||
self.flow.request.scheme == "https"
|
||||
@ -244,7 +243,7 @@ class HttpStream(Layer):
|
||||
self.context.server = self.flow.server_conn = connection
|
||||
return True
|
||||
|
||||
def handle_connect(self) -> commands.TCommandGenerator:
|
||||
def handle_connect(self) -> layer.CommandGenerator[None]:
|
||||
yield HttpConnectHook(self.flow)
|
||||
|
||||
self.context.server = Server((self.flow.request.host, self.flow.request.port))
|
||||
@ -261,7 +260,7 @@ class HttpStream(Layer):
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||
|
||||
if 200 <= self.flow.response.status_code < 300:
|
||||
self.child_layer = NextLayer(self.context)
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
self._handle_event = self.passthrough
|
||||
else:
|
||||
@ -269,7 +268,7 @@ class HttpStream(Layer):
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage, events.Event)
|
||||
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
# HTTP events -> normal connection events
|
||||
if isinstance(event, RequestData):
|
||||
event = events.DataReceived(self.context.client, event.data)
|
||||
@ -288,14 +287,14 @@ class HttpStream(Layer):
|
||||
yield command
|
||||
|
||||
@expect()
|
||||
def state_uninitialized(self, _) -> commands.TCommandGenerator:
|
||||
def state_uninitialized(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
@expect()
|
||||
def state_done(self, _) -> commands.TCommandGenerator:
|
||||
def state_done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
def state_errored(self, _) -> commands.TCommandGenerator:
|
||||
def state_errored(self, _) -> layer.CommandGenerator[None]:
|
||||
# silently consume every event.
|
||||
yield from ()
|
||||
|
||||
|
@ -2,7 +2,7 @@ import abc
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2 import events, layer
|
||||
|
||||
StreamId = int
|
||||
|
||||
@ -24,7 +24,7 @@ class HttpConnection(abc.ABC):
|
||||
yield from ()
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
|
||||
|
@ -8,7 +8,7 @@ from h11._receivebuffer import ReceiveBuffer
|
||||
from mitmproxy import http
|
||||
from mitmproxy.net.http import http1
|
||||
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Client, Connection, Server
|
||||
from mitmproxy.proxy2.layers.http._base import StreamId
|
||||
from ._base import HttpConnection
|
||||
@ -40,7 +40,7 @@ class Http1Connection(HttpConnection):
|
||||
yield from self.state(event)
|
||||
|
||||
@abstractmethod
|
||||
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
def make_body_reader(self, expected_size: typing.Optional[int]) -> TBodyReader:
|
||||
@ -106,7 +106,7 @@ class Http1Server(Http1Connection):
|
||||
self.stream_id = 1
|
||||
self.state = self.read_request_headers
|
||||
|
||||
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
assert event.stream_id == self.stream_id
|
||||
if isinstance(event, ResponseHeaders):
|
||||
self.response = event.response
|
||||
@ -197,7 +197,7 @@ class Http1Client(Http1Connection):
|
||||
self.state = self.read_response_headers
|
||||
self.send_queue = []
|
||||
|
||||
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
if not self.stream_id:
|
||||
assert isinstance(event, RequestHeaders)
|
||||
self.stream_id = event.stream_id
|
||||
|
@ -7,7 +7,7 @@ from mitmproxy.proxy2.utils import expect
|
||||
|
||||
class ReverseProxy(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
|
||||
self.context.server = Server(spec.address)
|
||||
if spec.scheme not in ("http", "tcp"):
|
||||
@ -21,7 +21,7 @@ class ReverseProxy(layer.Layer):
|
||||
|
||||
class HttpProxy(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
yield from child_layer.handle_event(event)
|
||||
@ -29,7 +29,7 @@ class HttpProxy(layer.Layer):
|
||||
|
||||
class TransparentProxy(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
socket = yield commands.GetSocket(self.context.client)
|
||||
try:
|
||||
self.context.server.address = platform.original_addr(socket)
|
||||
|
@ -4,13 +4,12 @@ from wsproto import ConnectionType, WSConnection
|
||||
from wsproto.extensions import PerMessageDeflate
|
||||
|
||||
from mitmproxy import websocket, http, flow
|
||||
from mitmproxy.proxy2 import events, commands
|
||||
from mitmproxy.proxy2 import events, commands, layer
|
||||
from mitmproxy.proxy2.context import Context
|
||||
from mitmproxy.proxy2.layer import Layer
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
|
||||
|
||||
class WebsocketLayer(Layer):
|
||||
class WebsocketLayer(layer.Layer):
|
||||
"""
|
||||
WebSocket layer that intercepts and relays messages.
|
||||
"""
|
||||
@ -29,7 +28,7 @@ class WebsocketLayer(Layer):
|
||||
assert context.server.connected
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> commands.TCommandGenerator:
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
extensions = []
|
||||
if 'Sec-WebSocket-Extensions' in self.handshake_flow.response.headers:
|
||||
if PerMessageDeflate.name in self.handshake_flow.response.headers['Sec-WebSocket-Extensions']:
|
||||
@ -60,7 +59,7 @@ class WebsocketLayer(Layer):
|
||||
_handle_event = start
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def process_data(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def process_data(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
from_client = event.connection == self.context.client
|
||||
if from_client:
|
||||
|
@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from mitmproxy import flow, tcp
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.commands import Hook
|
||||
from mitmproxy.proxy2.context import Context
|
||||
from mitmproxy.proxy2.layer import Layer
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
|
||||
|
||||
@ -24,7 +23,7 @@ class TcpErrorHook(Hook):
|
||||
flow: tcp.TCPFlow
|
||||
|
||||
|
||||
class TCPLayer(Layer):
|
||||
class TCPLayer(layer.Layer):
|
||||
"""
|
||||
Simple TCP layer that just relays messages right now.
|
||||
"""
|
||||
@ -39,7 +38,7 @@ class TCPLayer(Layer):
|
||||
self.flow = tcp.TCPFlow(self.context.client, self.context.server, True)
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> commands.TCommandGenerator:
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
if self.flow:
|
||||
yield TcpStartHook(self.flow)
|
||||
|
||||
@ -57,7 +56,7 @@ class TCPLayer(Layer):
|
||||
_handle_event = start
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def relay_messages(self, event: events.ConnectionEvent) -> commands.TCommandGenerator:
|
||||
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
from_client = event.connection == self.context.client
|
||||
if from_client:
|
||||
send_to = self.context.server
|
||||
@ -68,7 +67,7 @@ class TCPLayer(Layer):
|
||||
if self.flow:
|
||||
tcp_message = tcp.TCPMessage(from_client, event.data)
|
||||
self.flow.messages.append(tcp_message)
|
||||
yield TcpMessageHook(self.flow)
|
||||
yield TcpMessageHook(self.flow)t
|
||||
yield commands.SendData(send_to, tcp_message.content)
|
||||
else:
|
||||
yield commands.SendData(send_to, event.data)
|
||||
@ -82,5 +81,5 @@ class TCPLayer(Layer):
|
||||
yield TcpEndHook(self.flow)
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def done(self, _):
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
@ -163,7 +163,7 @@ class _TLSLayer(layer.Layer):
|
||||
|
||||
yield from self.negotiate(conn, initial_data)
|
||||
|
||||
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
|
||||
def tls_interact(self, conn: context.Connection) -> layer.CommandGenerator[None]:
|
||||
while True:
|
||||
try:
|
||||
data = self.tls[conn].bio_read(65535)
|
||||
@ -173,8 +173,7 @@ class _TLSLayer(layer.Layer):
|
||||
else:
|
||||
yield commands.SendData(conn, data)
|
||||
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
|
||||
commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||
if data:
|
||||
self.tls[conn].bio_write(data)
|
||||
@ -248,7 +247,7 @@ class _TLSLayer(layer.Layer):
|
||||
events.ConnectionClosed(conn)
|
||||
)
|
||||
|
||||
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
for command in self.child_layer.handle_event(event):
|
||||
if isinstance(command, commands.SendData) and command.connection in self.tls:
|
||||
self.tls[command.connection].sendall(command.data)
|
||||
@ -256,7 +255,7 @@ class _TLSLayer(layer.Layer):
|
||||
else:
|
||||
yield command
|
||||
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived) and event.connection in self.tls:
|
||||
if not event.connection.tls_established:
|
||||
yield from self.negotiate(event.connection, event.data)
|
||||
@ -273,7 +272,7 @@ class _TLSLayer(layer.Layer):
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> commands.TCommandGenerator:
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.CloseConnection(conn)
|
||||
|
||||
|
||||
@ -281,22 +280,21 @@ class ServerTLSLayer(_TLSLayer):
|
||||
"""
|
||||
This layer manages TLS for potentially multiple server connections.
|
||||
"""
|
||||
command_to_reply_to: Dict[context.Connection, EstablishServerTLS]
|
||||
command_to_reply_to: Dict[context.Connection, commands.OpenConnection]
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context)
|
||||
self.command_to_reply_to = {}
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
def negotiate(self, conn: context.Connection, data: bytes) \
|
||||
-> Generator[commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
|
||||
done, err = yield from super().negotiate(conn, data)
|
||||
if done or err:
|
||||
cmd = self.command_to_reply_to.pop(conn)
|
||||
yield from self.event_to_child(EstablishServerTLSReply(cmd, err))
|
||||
yield from self.event_to_child(events.OpenConnectionReply(cmd, err))
|
||||
return done, err
|
||||
|
||||
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
for command in super().event_to_child(event):
|
||||
if isinstance(command, EstablishServerTLS):
|
||||
self.command_to_reply_to[command.connection] = command
|
||||
@ -304,7 +302,7 @@ class ServerTLSLayer(_TLSLayer):
|
||||
else:
|
||||
yield command
|
||||
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> commands.TCommandGenerator:
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.Log(
|
||||
f"Server TLS handshake failed. {err}",
|
||||
level="warn"
|
||||
@ -339,12 +337,12 @@ class ClientTLSLayer(_TLSLayer):
|
||||
self._handle_event = self.state_start
|
||||
|
||||
@expect(events.Start)
|
||||
def state_start(self, _) -> commands.TCommandGenerator:
|
||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||
self.context.client.tls = True
|
||||
self._handle_event = self.state_wait_for_clienthello
|
||||
yield from ()
|
||||
|
||||
def state_wait_for_clienthello(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
client = self.context.client
|
||||
if isinstance(event, events.DataReceived) and event.connection == client:
|
||||
self.recv_buffer.extend(event.data)
|
||||
@ -376,7 +374,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
def start_server_tls(self):
|
||||
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
|
||||
"""
|
||||
We often need information from the upstream connection to establish TLS with the client.
|
||||
For example, we need to check if the client does ALPN or not.
|
||||
@ -397,7 +395,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
)
|
||||
return err
|
||||
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> commands.TCommandGenerator:
|
||||
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
|
||||
if conn.sni:
|
||||
dest = conn.sni.decode("idna")
|
||||
else:
|
||||
|
@ -109,7 +109,7 @@ def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: context.Connectio
|
||||
class TlsEchoLayer(tutils.EchoLayer):
|
||||
err: typing.Optional[str] = None
|
||||
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
|
||||
# noinspection PyTypeChecker
|
||||
err = yield tls.EstablishServerTLS(self.context.server)
|
||||
|
@ -2,8 +2,7 @@ import typing
|
||||
|
||||
import pytest
|
||||
|
||||
from mitmproxy.proxy2 import events, commands
|
||||
from mitmproxy.proxy2.layer import Layer
|
||||
from mitmproxy.proxy2 import events, commands, layer
|
||||
from . import tutils
|
||||
|
||||
|
||||
@ -25,12 +24,12 @@ class TCommandReply(events.CommandReply):
|
||||
command: TCommand
|
||||
|
||||
|
||||
class TLayer(Layer):
|
||||
class TLayer(layer.Layer):
|
||||
"""
|
||||
Simple echo layer
|
||||
"""
|
||||
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, TEvent):
|
||||
for x in event.commands:
|
||||
yield TCommand(x)
|
||||
|
@ -327,7 +327,7 @@ def Placeholder() -> typing.Any:
|
||||
class EchoLayer(Layer):
|
||||
"""Echo layer that sends all data back to the client in lowercase."""
|
||||
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
yield commands.SendData(event.connection, event.data.lower())
|
||||
if isinstance(event, events.ConnectionClosed):
|
||||
|
Loading…
Reference in New Issue
Block a user