[sans-io] refactor

This commit is contained in:
Maximilian Hils 2019-12-21 20:00:01 +01:00
parent 0740c673bd
commit 7efe27be74
12 changed files with 79 additions and 70 deletions

View File

@ -40,7 +40,7 @@ class ConnectionCommand(Command):
"""
connection: Connection
def __init__(self, connection: Connection) -> None:
def __init__(self, connection: Connection):
self.connection = connection
@ -50,7 +50,7 @@ class SendData(ConnectionCommand):
"""
data: bytes
def __init__(self, connection: Connection, data: bytes) -> None:
def __init__(self, connection: Connection, data: bytes):
super().__init__(connection)
self.data = data
@ -74,6 +74,16 @@ class CloseConnection(ConnectionCommand):
"""
class ProtocolError(Command):
"""
Indicate that an unrecoverable protocol error has occured.
"""
message: str
def __init__(self, message: str):
self.message = message
class Hook(Command):
"""
Callback to the master (like ".ask()")
@ -135,4 +145,3 @@ class Log(Command):
return f"Log({self.message}, {self.level})"
TCommandGenerator = typing.Generator[Command, typing.Any, None]

View File

@ -8,16 +8,22 @@ from abc import abstractmethod
from mitmproxy import log
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.commands import Hook
from mitmproxy.proxy2.commands import Command, Hook
from mitmproxy.proxy2.context import Connection, Context
T = typing.TypeVar('T')
CommandGenerator = typing.Generator[Command, typing.Optional[events.CommandReply], T]
"""
A function annotated with CommandGenerator[bool] may yield commands and ultimately return a boolean value.
"""
class Paused(typing.NamedTuple):
"""
State of a layer that's paused because it is waiting for a command reply.
"""
command: commands.Command
generator: commands.TCommandGenerator
generator: CommandGenerator
class Layer:
@ -65,11 +71,11 @@ class Layer:
)
@abstractmethod
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
"""Handle a proxy server event"""
yield from () # pragma: no cover
def handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def handle_event(self, event: events.Event) -> CommandGenerator[None]:
if self._paused:
# did we just receive the reply we were waiting for?
pause_finished = (
@ -88,7 +94,7 @@ class Layer:
command_generator = self._handle_event(event)
yield from self.__process(command_generator)
def __process(self, command_generator: commands.TCommandGenerator, send=None):
def __process(self, command_generator: CommandGenerator, send=None):
"""
yield all commands from a generator.
if a command is blocking, the layer is paused and this function returns before

View File

@ -3,9 +3,8 @@ import typing
from mitmproxy import flow, http
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Connection, Context, Server
from mitmproxy.proxy2.layer import Layer, NextLayer
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
@ -60,12 +59,12 @@ class SendHttp(HttpCommand):
return f"Send({self.event})"
class HttpStream(Layer):
class HttpStream(layer.Layer):
request_body_buf: bytes
response_body_buf: bytes
flow: http.HTTPFlow
stream_id: StreamId
child_layer: typing.Optional[Layer] = None
child_layer: typing.Optional[layer.Layer] = None
@property
def mode(self):
@ -80,7 +79,7 @@ class HttpStream(Layer):
self.server_state = self.state_uninitialized
@expect(events.Start, HttpEvent)
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.Start):
self.client_state = self.state_wait_for_request_headers
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
@ -91,7 +90,7 @@ class HttpStream(Layer):
yield from self.server_state(event)
@expect(RequestHeaders)
def state_wait_for_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
def state_wait_for_request_headers(self, event: RequestHeaders) -> layer.CommandGenerator[None]:
self.stream_id = event.stream_id
self.flow = http.HTTPFlow(
self.context.client,
@ -146,7 +145,7 @@ class HttpStream(Layer):
self.server_state = self.state_wait_for_response_headers
@expect(RequestData, RequestEndOfMessage)
def state_stream_request_body(self, event: events.Event) -> commands.TCommandGenerator:
def state_stream_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData):
if callable(self.flow.request.stream):
data = self.flow.request.stream(event.data)
@ -158,7 +157,7 @@ class HttpStream(Layer):
self.client_state = self.state_done
@expect(RequestData, RequestEndOfMessage)
def state_consume_request_body(self, event: events.Event) -> commands.TCommandGenerator:
def state_consume_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData):
self.request_body_buf += event.data
elif isinstance(event, RequestEndOfMessage):
@ -182,7 +181,7 @@ class HttpStream(Layer):
self.client_state = self.state_done
@expect(ResponseHeaders)
def state_wait_for_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
self.flow.response = event.response
yield HttpResponseHeadersHook(self.flow)
if self.flow.response.stream:
@ -192,7 +191,7 @@ class HttpStream(Layer):
self.server_state = self.state_consume_response_body
@expect(ResponseData, ResponseEndOfMessage)
def state_stream_response_body(self, event: events.Event) -> commands.TCommandGenerator:
def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, ResponseData):
if callable(self.flow.response.stream):
data = self.flow.response.stream(event.data)
@ -204,7 +203,7 @@ class HttpStream(Layer):
self.server_state = self.state_done
@expect(ResponseData, ResponseEndOfMessage)
def state_consume_response_body(self, event: events.Event) -> commands.TCommandGenerator:
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, ResponseData):
self.response_body_buf += event.data
elif isinstance(event, ResponseEndOfMessage):
@ -222,7 +221,7 @@ class HttpStream(Layer):
def handle_protocol_error(
self,
event: typing.Union[RequestProtocolError, ResponseProtocolError]
) -> commands.TCommandGenerator:
) -> layer.CommandGenerator[None]:
self.flow.error = flow.Error(event.message)
yield HttpErrorHook(self.flow)
@ -232,7 +231,7 @@ class HttpStream(Layer):
yield SendHttp(event, self.context.client)
return
def make_server_connection(self) -> typing.Generator[commands.Command, typing.Any, bool]:
def make_server_connection(self) -> layer.CommandGenerator[bool]:
connection, err = yield GetHttpConnection(
(self.flow.request.host, self.flow.request.port),
self.flow.request.scheme == "https"
@ -244,7 +243,7 @@ class HttpStream(Layer):
self.context.server = self.flow.server_conn = connection
return True
def handle_connect(self) -> commands.TCommandGenerator:
def handle_connect(self) -> layer.CommandGenerator[None]:
yield HttpConnectHook(self.flow)
self.context.server = Server((self.flow.request.host, self.flow.request.port))
@ -261,7 +260,7 @@ class HttpStream(Layer):
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
if 200 <= self.flow.response.status_code < 300:
self.child_layer = NextLayer(self.context)
self.child_layer = layer.NextLayer(self.context)
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
else:
@ -269,7 +268,7 @@ class HttpStream(Layer):
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
@expect(RequestData, RequestEndOfMessage, events.Event)
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
# HTTP events -> normal connection events
if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data)
@ -288,14 +287,14 @@ class HttpStream(Layer):
yield command
@expect()
def state_uninitialized(self, _) -> commands.TCommandGenerator:
def state_uninitialized(self, _) -> layer.CommandGenerator[None]:
yield from ()
@expect()
def state_done(self, _) -> commands.TCommandGenerator:
def state_done(self, _) -> layer.CommandGenerator[None]:
yield from ()
def state_errored(self, _) -> commands.TCommandGenerator:
def state_errored(self, _) -> layer.CommandGenerator[None]:
# silently consume every event.
yield from ()

View File

@ -2,7 +2,7 @@ import abc
import typing
from dataclasses import dataclass
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2 import events, layer
StreamId = int
@ -24,7 +24,7 @@ class HttpConnection(abc.ABC):
yield from ()
@abc.abstractmethod
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
yield from ()

View File

@ -8,7 +8,7 @@ from h11._receivebuffer import ReceiveBuffer
from mitmproxy import http
from mitmproxy.net.http import http1
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Client, Connection, Server
from mitmproxy.proxy2.layers.http._base import StreamId
from ._base import HttpConnection
@ -40,7 +40,7 @@ class Http1Connection(HttpConnection):
yield from self.state(event)
@abstractmethod
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
yield from ()
def make_body_reader(self, expected_size: typing.Optional[int]) -> TBodyReader:
@ -106,7 +106,7 @@ class Http1Server(Http1Connection):
self.stream_id = 1
self.state = self.read_request_headers
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
assert event.stream_id == self.stream_id
if isinstance(event, ResponseHeaders):
self.response = event.response
@ -197,7 +197,7 @@ class Http1Client(Http1Connection):
self.state = self.read_response_headers
self.send_queue = []
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
if not self.stream_id:
assert isinstance(event, RequestHeaders)
self.stream_id = event.stream_id

View File

@ -7,7 +7,7 @@ from mitmproxy.proxy2.utils import expect
class ReverseProxy(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
self.context.server = Server(spec.address)
if spec.scheme not in ("http", "tcp"):
@ -21,7 +21,7 @@ class ReverseProxy(layer.Layer):
class HttpProxy(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
child_layer = layer.NextLayer(self.context)
self._handle_event = child_layer.handle_event
yield from child_layer.handle_event(event)
@ -29,7 +29,7 @@ class HttpProxy(layer.Layer):
class TransparentProxy(layer.Layer):
@expect(events.Start)
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
socket = yield commands.GetSocket(self.context.client)
try:
self.context.server.address = platform.original_addr(socket)

View File

@ -4,13 +4,12 @@ from wsproto import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
from mitmproxy import websocket, http, flow
from mitmproxy.proxy2 import events, commands
from mitmproxy.proxy2 import events, commands, layer
from mitmproxy.proxy2.context import Context
from mitmproxy.proxy2.layer import Layer
from mitmproxy.proxy2.utils import expect
class WebsocketLayer(Layer):
class WebsocketLayer(layer.Layer):
"""
WebSocket layer that intercepts and relays messages.
"""
@ -29,7 +28,7 @@ class WebsocketLayer(Layer):
assert context.server.connected
@expect(events.Start)
def start(self, _) -> commands.TCommandGenerator:
def start(self, _) -> layer.CommandGenerator[None]:
extensions = []
if 'Sec-WebSocket-Extensions' in self.handshake_flow.response.headers:
if PerMessageDeflate.name in self.handshake_flow.response.headers['Sec-WebSocket-Extensions']:
@ -60,7 +59,7 @@ class WebsocketLayer(Layer):
_handle_event = start
@expect(events.DataReceived, events.ConnectionClosed)
def process_data(self, event: events.Event) -> commands.TCommandGenerator:
def process_data(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
from_client = event.connection == self.context.client
if from_client:

View File

@ -1,10 +1,9 @@
from typing import Optional
from mitmproxy import flow, tcp
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.commands import Hook
from mitmproxy.proxy2.context import Context
from mitmproxy.proxy2.layer import Layer
from mitmproxy.proxy2.utils import expect
@ -24,7 +23,7 @@ class TcpErrorHook(Hook):
flow: tcp.TCPFlow
class TCPLayer(Layer):
class TCPLayer(layer.Layer):
"""
Simple TCP layer that just relays messages right now.
"""
@ -39,7 +38,7 @@ class TCPLayer(Layer):
self.flow = tcp.TCPFlow(self.context.client, self.context.server, True)
@expect(events.Start)
def start(self, _) -> commands.TCommandGenerator:
def start(self, _) -> layer.CommandGenerator[None]:
if self.flow:
yield TcpStartHook(self.flow)
@ -57,7 +56,7 @@ class TCPLayer(Layer):
_handle_event = start
@expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.ConnectionEvent) -> commands.TCommandGenerator:
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
from_client = event.connection == self.context.client
if from_client:
send_to = self.context.server
@ -68,7 +67,7 @@ class TCPLayer(Layer):
if self.flow:
tcp_message = tcp.TCPMessage(from_client, event.data)
self.flow.messages.append(tcp_message)
yield TcpMessageHook(self.flow)
yield TcpMessageHook(self.flow)t
yield commands.SendData(send_to, tcp_message.content)
else:
yield commands.SendData(send_to, event.data)
@ -82,5 +81,5 @@ class TCPLayer(Layer):
yield TcpEndHook(self.flow)
@expect(events.DataReceived, events.ConnectionClosed)
def done(self, _):
def done(self, _) -> layer.CommandGenerator[None]:
yield from ()

View File

@ -163,7 +163,7 @@ class _TLSLayer(layer.Layer):
yield from self.negotiate(conn, initial_data)
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
def tls_interact(self, conn: context.Connection) -> layer.CommandGenerator[None]:
while True:
try:
data = self.tls[conn].bio_read(65535)
@ -173,8 +173,7 @@ class _TLSLayer(layer.Layer):
else:
yield commands.SendData(conn, data)
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
commands.Command, Any, Tuple[bool, Optional[str]]]:
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
# bio_write errors for b"", so we need to check first if we actually received something.
if data:
self.tls[conn].bio_write(data)
@ -248,7 +247,7 @@ class _TLSLayer(layer.Layer):
events.ConnectionClosed(conn)
)
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
for command in self.child_layer.handle_event(event):
if isinstance(command, commands.SendData) and command.connection in self.tls:
self.tls[command.connection].sendall(command.data)
@ -256,7 +255,7 @@ class _TLSLayer(layer.Layer):
else:
yield command
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.connection in self.tls:
if not event.connection.tls_established:
yield from self.negotiate(event.connection, event.data)
@ -273,7 +272,7 @@ class _TLSLayer(layer.Layer):
else:
yield from self.event_to_child(event)
def on_handshake_error(self, conn: context.Connection, err: str) -> commands.TCommandGenerator:
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
yield commands.CloseConnection(conn)
@ -281,22 +280,21 @@ class ServerTLSLayer(_TLSLayer):
"""
This layer manages TLS for potentially multiple server connections.
"""
command_to_reply_to: Dict[context.Connection, EstablishServerTLS]
command_to_reply_to: Dict[context.Connection, commands.OpenConnection]
def __init__(self, context: context.Context):
super().__init__(context)
self.command_to_reply_to = {}
self.child_layer = layer.NextLayer(self.context)
def negotiate(self, conn: context.Connection, data: bytes) \
-> Generator[commands.Command, Any, Tuple[bool, Optional[str]]]:
def negotiate(self, conn: context.Connection, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
done, err = yield from super().negotiate(conn, data)
if done or err:
cmd = self.command_to_reply_to.pop(conn)
yield from self.event_to_child(EstablishServerTLSReply(cmd, err))
yield from self.event_to_child(events.OpenConnectionReply(cmd, err))
return done, err
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
for command in super().event_to_child(event):
if isinstance(command, EstablishServerTLS):
self.command_to_reply_to[command.connection] = command
@ -304,7 +302,7 @@ class ServerTLSLayer(_TLSLayer):
else:
yield command
def on_handshake_error(self, conn: context.Connection, err: str) -> commands.TCommandGenerator:
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
yield commands.Log(
f"Server TLS handshake failed. {err}",
level="warn"
@ -339,12 +337,12 @@ class ClientTLSLayer(_TLSLayer):
self._handle_event = self.state_start
@expect(events.Start)
def state_start(self, _) -> commands.TCommandGenerator:
def state_start(self, _) -> layer.CommandGenerator[None]:
self.context.client.tls = True
self._handle_event = self.state_wait_for_clienthello
yield from ()
def state_wait_for_clienthello(self, event: events.Event) -> commands.TCommandGenerator:
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
client = self.context.client
if isinstance(event, events.DataReceived) and event.connection == client:
self.recv_buffer.extend(event.data)
@ -376,7 +374,7 @@ class ClientTLSLayer(_TLSLayer):
else:
yield from self.event_to_child(event)
def start_server_tls(self):
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
"""
We often need information from the upstream connection to establish TLS with the client.
For example, we need to check if the client does ALPN or not.
@ -397,7 +395,7 @@ class ClientTLSLayer(_TLSLayer):
)
return err
def on_handshake_error(self, conn: context.Connection, err: str) -> commands.TCommandGenerator:
def on_handshake_error(self, conn: context.Connection, err: str) -> layer.CommandGenerator[None]:
if conn.sni:
dest = conn.sni.decode("idna")
else:

View File

@ -109,7 +109,7 @@ def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: context.Connectio
class TlsEchoLayer(tutils.EchoLayer):
err: typing.Optional[str] = None
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
# noinspection PyTypeChecker
err = yield tls.EstablishServerTLS(self.context.server)

View File

@ -2,8 +2,7 @@ import typing
import pytest
from mitmproxy.proxy2 import events, commands
from mitmproxy.proxy2.layer import Layer
from mitmproxy.proxy2 import events, commands, layer
from . import tutils
@ -25,12 +24,12 @@ class TCommandReply(events.CommandReply):
command: TCommand
class TLayer(Layer):
class TLayer(layer.Layer):
"""
Simple echo layer
"""
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, TEvent):
for x in event.commands:
yield TCommand(x)

View File

@ -327,7 +327,7 @@ def Placeholder() -> typing.Any:
class EchoLayer(Layer):
"""Echo layer that sends all data back to the client in lowercase."""
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
yield commands.SendData(event.connection, event.data.lower())
if isinstance(event, events.ConnectionClosed):