[sans-io] implement http streaming, refine error handling

This commit is contained in:
Maximilian Hils 2019-11-25 04:13:21 +01:00
parent 5671012163
commit 0740c673bd
8 changed files with 460 additions and 237 deletions

View File

@ -112,6 +112,9 @@ class Hook(Command):
all_hooks: typing.Dict[str, typing.Type[Hook]] = {} all_hooks: typing.Dict[str, typing.Type[Hook]] = {}
# TODO: Move descriptions from addons/events.py into hooks and have hook documentation generated from all_hooks.
class GetSocket(ConnectionCommand): class GetSocket(ConnectionCommand):
""" """
Get the underlying socket. Get the underlying socket.

View File

@ -5,6 +5,7 @@ The counterpart to events are commands.
""" """
import socket import socket
import typing import typing
from dataclasses import dataclass
from mitmproxy.proxy2 import commands from mitmproxy.proxy2 import commands
from mitmproxy.proxy2.context import Connection from mitmproxy.proxy2.context import Connection
@ -27,15 +28,13 @@ class Start(Event):
pass pass
@dataclass
class ConnectionEvent(Event): class ConnectionEvent(Event):
""" """
All events involving connection IO. All events involving connection IO.
""" """
connection: Connection connection: Connection
def __init__(self, connection: Connection):
self.connection = connection
class ConnectionClosed(ConnectionEvent): class ConnectionClosed(ConnectionEvent):
""" """
@ -44,20 +43,19 @@ class ConnectionClosed(ConnectionEvent):
pass pass
@dataclass
class DataReceived(ConnectionEvent): class DataReceived(ConnectionEvent):
""" """
Remote has sent some data. Remote has sent some data.
""" """
data: bytes
def __init__(self, connection: Connection, data: bytes) -> None:
super().__init__(connection)
self.data = data
def __repr__(self): def __repr__(self):
target = type(self.connection).__name__.lower() target = type(self.connection).__name__.lower()
return f"DataReceived({target}, {self.data})" return f"DataReceived({target}, {self.data})"
@dataclass
class CommandReply(Event): class CommandReply(Event):
""" """
Emitted when a command has been finished, e.g. Emitted when a command has been finished, e.g.
@ -66,10 +64,6 @@ class CommandReply(Event):
command: commands.Command command: commands.Command
reply: typing.Any reply: typing.Any
def __init__(self, command: commands.Command, reply: typing.Any):
self.command = command
self.reply = reply
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if cls is CommandReply: if cls is CommandReply:
raise TypeError("CommandReply may not be instantiated directly.") raise TypeError("CommandReply may not be instantiated directly.")
@ -88,35 +82,23 @@ class CommandReply(Event):
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {} command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
@dataclass
class OpenConnectionReply(CommandReply): class OpenConnectionReply(CommandReply):
command: commands.OpenConnection command: commands.OpenConnection
reply: typing.Optional[str] reply: typing.Optional[str]
"""error message"""
def __init__(
self,
command: commands.OpenConnection,
err: typing.Optional[str]
):
super().__init__(command, err)
@dataclass
class HookReply(CommandReply): class HookReply(CommandReply):
command: commands.Hook command: commands.Hook
reply: None = None
def __init__(self, command: commands.Hook):
super().__init__(command, None)
def __repr__(self): def __repr__(self):
return f"HookReply({repr(self.command)[5:-1]})" return f"HookReply({repr(self.command)[5:-1]})"
@dataclass
class GetSocketReply(CommandReply): class GetSocketReply(CommandReply):
command: commands.GetSocket command: commands.GetSocket
reply: socket.socket reply: socket.socket
def __init__(
self,
command: commands.GetSocket,
socket: socket.socket
):
super().__init__(command, socket)

View File

@ -10,8 +10,10 @@ from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSRe
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
from ._base import HttpConnection, StreamId from ._base import HttpConnection, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseHeaders ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
HttpResponseHook
from ._http1 import Http1Client, Http1Server from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client from ._http2 import Http2Client
@ -42,8 +44,8 @@ class GetHttpConnection(HttpCommand):
class GetHttpConnectionReply(events.CommandReply): class GetHttpConnectionReply(events.CommandReply):
command: GetHttpConnection command: GetHttpConnection
reply: typing.Optional[str] reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
"""error message""" """connection object, error message"""
class SendHttp(HttpCommand): class SendHttp(HttpCommand):
@ -58,35 +60,6 @@ class SendHttp(HttpCommand):
return f"Send({self.event})" return f"Send({self.event})"
class HttpRequestHeadersHook(commands.Hook):
name = "requestheaders"
flow: http.HTTPFlow
class HttpRequestHook(commands.Hook):
name = "request"
flow: http.HTTPFlow
class HttpResponseHook(commands.Hook):
name = "response"
flow: http.HTTPFlow
class HttpResponseHeadersHook(commands.Hook):
name = "responseheaders"
flow: http.HTTPFlow
class HttpConnectHook(commands.Hook):
flow: http.HTTPFlow
class HttpErrorHook(commands.Hook):
name = "error"
flow: http.HTTPFlow
class HttpStream(Layer): class HttpStream(Layer):
request_body_buf: bytes request_body_buf: bytes
response_body_buf: bytes response_body_buf: bytes
@ -103,15 +76,22 @@ class HttpStream(Layer):
super().__init__(context) super().__init__(context)
self.request_body_buf = b"" self.request_body_buf = b""
self.response_body_buf = b"" self.response_body_buf = b""
self._handle_event = self.start self.client_state = self.state_uninitialized
self.server_state = self.state_uninitialized
@expect(events.Start) @expect(events.Start, HttpEvent)
def start(self, event: events.Event) -> commands.TCommandGenerator: def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
self._handle_event = self.read_request_headers if isinstance(event, events.Start):
yield from () self.client_state = self.state_wait_for_request_headers
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
yield from self.handle_protocol_error(event)
elif isinstance(event, (RequestHeaders, RequestData, RequestEndOfMessage)):
yield from self.client_state(event)
else:
yield from self.server_state(event)
@expect(RequestHeaders) @expect(RequestHeaders)
def read_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator: def state_wait_for_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
self.stream_id = event.stream_id self.stream_id = event.stream_id
self.flow = http.HTTPFlow( self.flow = http.HTTPFlow(
self.context.client, self.context.client,
@ -122,72 +102,12 @@ class HttpStream(Layer):
if self.flow.request.first_line_format == "authority": if self.flow.request.first_line_format == "authority":
yield from self.handle_connect() yield from self.handle_connect()
return return
else:
yield HttpRequestHeadersHook(self.flow)
if self.flow.request.headers.get("expect", "").lower() == "100-continue": if self.flow.request.headers.get("expect", "").lower() == "100-continue":
raise NotImplementedError("expect nothing") raise NotImplementedError("expect nothing")
# self.send_response(http.expect_continue_response) # self.send_response(http.expect_continue_response)
# request.headers.pop("expect") # request.headers.pop("expect")
if self.flow.request.stream:
raise NotImplementedError # FIXME
else:
self._handle_event = self.read_request_body
@expect(RequestData, RequestEndOfMessage)
def read_request_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, RequestData):
self.request_body_buf += event.data
elif isinstance(event, RequestEndOfMessage):
self.flow.request.data.content = self.request_body_buf
self.request_body_buf = b""
yield from self.handle_request()
def handle_connect(self) -> commands.TCommandGenerator:
yield HttpConnectHook(self.flow)
self.context.server = Server((self.flow.request.host, self.flow.request.port))
if self.context.options.connection_strategy == "eager":
err = yield commands.OpenConnection(self.context.server)
if err:
self.flow.response = http.HTTPResponse.make(
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
)
if not self.flow.response:
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client)
if 200 <= self.flow.response.status_code < 300:
self.child_layer = NextLayer(self.context)
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
else:
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
@expect(RequestData, RequestEndOfMessage, events.Event)
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
# HTTP events -> normal connection events
if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data)
elif isinstance(event, RequestEndOfMessage):
event = events.ConnectionClosed(self.context.client)
for command in self.child_layer.handle_event(event):
# normal connection events -> HTTP events
if isinstance(command, commands.SendData) and command.connection == self.context.client:
yield SendHttp(ResponseData(command.data, self.stream_id), self.context.client)
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
yield from self.passthrough(events.OpenConnectionReply(command, None))
else:
yield command
def handle_request(self) -> commands.TCommandGenerator:
# set first line format to relative in regular mode, # set first line format to relative in regular mode,
# see https://github.com/mitmproxy/mitmproxy/issues/1759 # see https://github.com/mitmproxy/mitmproxy/issues/1759
if self.mode is HTTPMode.regular and self.flow.request.first_line_format == "absolute": if self.mode is HTTPMode.regular and self.flow.request.first_line_format == "absolute":
@ -210,68 +130,174 @@ class HttpStream(Layer):
self.flow.request.port = self.context.server.address[1] self.flow.request.port = self.context.server.address[1]
self.flow.request.host_header = host_header # set again as .host overwrites this. self.flow.request.host_header = host_header # set again as .host overwrites this.
self.flow.request.scheme = "https" if self.context.server.tls else "http" self.flow.request.scheme = "https" if self.context.server.tls else "http"
yield HttpRequestHook(self.flow)
yield HttpRequestHeadersHook(self.flow)
if self.flow.request.stream:
if self.flow.response:
raise NotImplementedError("Can't set a response and enable streaming at the same time.")
ok = yield from self.make_server_connection()
if not ok:
return
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
self.client_state = self.state_stream_request_body
else:
self.client_state = self.state_consume_request_body
self.server_state = self.state_wait_for_response_headers
@expect(RequestData, RequestEndOfMessage)
def state_stream_request_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, RequestData):
if callable(self.flow.request.stream):
data = self.flow.request.stream(event.data)
else:
data = event.data
yield SendHttp(RequestData(self.stream_id, data), self.context.server)
elif isinstance(event, RequestEndOfMessage):
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
self.client_state = self.state_done
@expect(RequestData, RequestEndOfMessage)
def state_consume_request_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, RequestData):
self.request_body_buf += event.data
elif isinstance(event, RequestEndOfMessage):
self.flow.request.data.content = self.request_body_buf
self.request_body_buf = b""
yield HttpRequestHook(self.flow)
if self.flow.response: if self.flow.response:
# response was set by an inline script. # response was set by an inline script.
# we now need to emulate the responseheaders hook. # we now need to emulate the responseheaders hook.
yield HttpResponseHeadersHook(self.flow) yield HttpResponseHeadersHook(self.flow)
yield from self.handle_response() yield from self.send_response()
else: else:
connection, err = yield GetHttpConnection( ok = yield from self.make_server_connection()
(self.flow.request.host, self.flow.request.port), if not ok:
self.flow.request.scheme == "https"
)
if err:
yield from self.send_error_response(502, err)
self.flow.error = flow.Error(err)
yield HttpErrorHook(self.flow)
return return
else:
self.flow.server_conn = connection
yield SendHttp(RequestHeaders(self.flow.request, self.stream_id), connection) yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
yield SendHttp(RequestData(self.stream_id, self.flow.request.data.content), self.context.server)
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
if self.flow.request.stream: self.client_state = self.state_done
raise NotImplementedError
else:
yield SendHttp(RequestData(self.flow.request.data.content, self.stream_id), connection)
yield SendHttp(RequestEndOfMessage(self.stream_id), connection)
self._handle_event = self.read_response_headers
@expect(ResponseHeaders) @expect(ResponseHeaders)
def read_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator: def state_wait_for_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
self.flow.response = event.response self.flow.response = event.response
yield HttpResponseHeadersHook(self.flow) yield HttpResponseHeadersHook(self.flow)
if not self.flow.response.stream: if self.flow.response.stream:
self._handle_event = self.read_response_body yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
self.server_state = self.state_stream_response_body
else: else:
raise NotImplementedError self.server_state = self.state_consume_response_body
@expect(ResponseData, ResponseEndOfMessage) @expect(ResponseData, ResponseEndOfMessage)
def read_response_body(self, event: events.Event) -> commands.TCommandGenerator: def state_stream_response_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, ResponseData):
if callable(self.flow.response.stream):
data = self.flow.response.stream(event.data)
else:
data = event.data
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
elif isinstance(event, ResponseEndOfMessage):
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.server_state = self.state_done
@expect(ResponseData, ResponseEndOfMessage)
def state_consume_response_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, ResponseData): if isinstance(event, ResponseData):
self.response_body_buf += event.data self.response_body_buf += event.data
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
self.flow.response.data.content = self.response_body_buf self.flow.response.data.content = self.response_body_buf
self.response_body_buf = b"" self.response_body_buf = b""
yield from self.handle_response() yield from self.send_response()
def handle_response(self): def send_response(self):
yield HttpResponseHook(self.flow) yield HttpResponseHook(self.flow)
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client) yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.server_state = self.state_done
if self.flow.response.stream: def handle_protocol_error(
raise NotImplementedError self,
event: typing.Union[RequestProtocolError, ResponseProtocolError]
) -> commands.TCommandGenerator:
self.flow.error = flow.Error(event.message)
yield HttpErrorHook(self.flow)
if isinstance(event, RequestProtocolError):
yield SendHttp(event, self.context.server)
else: else:
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client) yield SendHttp(event, self.context.client)
return
def make_server_connection(self) -> typing.Generator[commands.Command, typing.Any, bool]:
connection, err = yield GetHttpConnection(
(self.flow.request.host, self.flow.request.port),
self.flow.request.scheme == "https"
)
if err:
yield from self.handle_protocol_error(ResponseProtocolError(self.stream_id, err))
return False
else:
self.context.server = self.flow.server_conn = connection
return True
def handle_connect(self) -> commands.TCommandGenerator:
yield HttpConnectHook(self.flow)
self.context.server = Server((self.flow.request.host, self.flow.request.port))
if self.context.options.connection_strategy == "eager":
err = yield commands.OpenConnection(self.context.server)
if err:
self.flow.response = http.HTTPResponse.make(
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
)
if not self.flow.response:
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
if 200 <= self.flow.response.status_code < 300:
self.child_layer = NextLayer(self.context)
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
else:
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
def send_error_response(self, status_code: int, message: str, headers=None): @expect(RequestData, RequestEndOfMessage, events.Event)
response = http.make_error_response(status_code, message, headers) def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
yield SendHttp(ResponseHeaders(response, self.stream_id), self.context.client) # HTTP events -> normal connection events
yield SendHttp(ResponseData(response.data.content, self.stream_id), self.context.client) if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data)
elif isinstance(event, RequestEndOfMessage):
event = events.ConnectionClosed(self.context.client)
for command in self.child_layer.handle_event(event):
# normal connection events -> HTTP events
if isinstance(command, commands.SendData) and command.connection == self.context.client:
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
yield from self.passthrough(events.OpenConnectionReply(command, None))
else:
yield command
@expect()
def state_uninitialized(self, _) -> commands.TCommandGenerator:
yield from ()
@expect()
def state_done(self, _) -> commands.TCommandGenerator:
yield from ()
def state_errored(self, _) -> commands.TCommandGenerator:
# silently consume every event.
yield from ()
class HTTPLayer(Layer): class HTTPLayer(Layer):
@ -304,8 +330,6 @@ class HTTPLayer(Layer):
self.connections = { self.connections = {
context.client: Http1Server(context.client) context.client: Http1Server(context.client)
} }
if self.context.server.connected:
self.make_http_connection(self.context.server)
def __repr__(self): def __repr__(self):
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})" return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
@ -322,8 +346,6 @@ class HTTPLayer(Layer):
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply))) self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
else: else:
yield from self.make_http_connection(event.command.connection) yield from self.make_http_connection(event.command.connection)
elif isinstance(event, EstablishServerTLSReply) and event.command.connection in self.waiting_for_connection:
yield from self.make_http_connection(event.command.connection)
elif isinstance(event, events.CommandReply): elif isinstance(event, events.CommandReply):
try: try:
stream = self.stream_by_command.pop(event.command) stream = self.stream_by_command.pop(event.command)

View File

@ -1,18 +1,16 @@
import abc import abc
import typing import typing
from dataclasses import dataclass
from mitmproxy.proxy2 import commands, events from mitmproxy.proxy2 import commands, events
StreamId = int StreamId = int
@dataclass
class HttpEvent(events.Event): class HttpEvent(events.Event):
stream_id: StreamId
# we need stream ids on every event to avoid race conditions # we need stream ids on every event to avoid race conditions
stream_id: StreamId
def __init__(self, stream_id: StreamId):
self.stream_id = stream_id
def __repr__(self) -> str: def __repr__(self) -> str:
x = self.__dict__.copy() x = self.__dict__.copy()

View File

@ -1,38 +1,28 @@
from dataclasses import dataclass
from mitmproxy import http from mitmproxy import http
from ._base import HttpEvent, StreamId from ._base import HttpEvent
@dataclass
class RequestHeaders(HttpEvent): class RequestHeaders(HttpEvent):
request: http.HTTPRequest request: http.HTTPRequest
def __init__(self, request: http.HTTPRequest, stream_id: StreamId):
super().__init__(stream_id)
self.request = request
@dataclass
class ResponseHeaders(HttpEvent): class ResponseHeaders(HttpEvent):
response: http.HTTPResponse response: http.HTTPResponse
def __init__(self, response: http.HTTPResponse, stream_id: StreamId):
super().__init__(stream_id)
self.response = response
@dataclass
class RequestData(HttpEvent): class RequestData(HttpEvent):
data: bytes data: bytes
def __init__(self, data: bytes, stream_id: StreamId):
super().__init__(stream_id)
self.data = data
@dataclass
class ResponseData(HttpEvent): class ResponseData(HttpEvent):
data: bytes data: bytes
def __init__(self, data: bytes, stream_id: StreamId):
super().__init__(stream_id)
self.data = data
class RequestEndOfMessage(HttpEvent): class RequestEndOfMessage(HttpEvent):
pass pass
@ -42,6 +32,18 @@ class ResponseEndOfMessage(HttpEvent):
pass pass
@dataclass
class RequestProtocolError(HttpEvent):
message: str
code: int = 400
@dataclass
class ResponseProtocolError(HttpEvent):
message: str
code: int = 502
__all__ = [ __all__ = [
"HttpEvent", "HttpEvent",
"RequestHeaders", "RequestHeaders",
@ -50,4 +52,6 @@ __all__ = [
"ResponseHeaders", "ResponseHeaders",
"ResponseData", "ResponseData",
"ResponseEndOfMessage", "ResponseEndOfMessage",
"RequestProtocolError",
"ResponseProtocolError",
] ]

View File

@ -0,0 +1,31 @@
from mitmproxy import http
from mitmproxy.proxy2 import commands
class HttpRequestHeadersHook(commands.Hook):
name = "requestheaders"
flow: http.HTTPFlow
class HttpRequestHook(commands.Hook):
name = "request"
flow: http.HTTPFlow
class HttpResponseHook(commands.Hook):
name = "response"
flow: http.HTTPFlow
class HttpResponseHeadersHook(commands.Hook):
name = "responseheaders"
flow: http.HTTPFlow
class HttpConnectHook(commands.Hook):
flow: http.HTTPFlow
class HttpErrorHook(commands.Hook):
name = "error"
flow: http.HTTPFlow

View File

@ -12,17 +12,19 @@ from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.context import Client, Connection, Server from mitmproxy.proxy2.context import Client, Connection, Server
from mitmproxy.proxy2.layers.http._base import StreamId from mitmproxy.proxy2.layers.http._base import StreamId
from ._base import HttpConnection from ._base import HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseHeaders ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader] TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
class Http1Connection(HttpConnection): class Http1Connection(HttpConnection):
conn: Connection conn: Connection
stream_id: StreamId = None stream_id: typing.Optional[StreamId] = None
request: http.HTTPRequest request: typing.Optional[http.HTTPRequest] = None
response: http.HTTPResponse response: typing.Optional[http.HTTPResponse] = None
request_done: bool = False
response_done: bool = False
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]] state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
body_reader: TBodyReader body_reader: TBodyReader
buf: ReceiveBuffer buf: ReceiveBuffer
@ -58,17 +60,22 @@ class Http1Connection(HttpConnection):
h11_event = self.body_reader.read_eof() h11_event = self.body_reader.read_eof()
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
except h11.ProtocolError: except h11.ProtocolError as e:
raise # FIXME yield commands.CloseConnection(self.conn)
if is_request:
yield RequestProtocolError(self.stream_id, str(e))
else:
yield ResponseProtocolError(self.stream_id, str(e))
return
if h11_event is None: if h11_event is None:
return return
elif isinstance(h11_event, h11.Data): elif isinstance(h11_event, h11.Data):
h11_event.data: bytearray # type checking h11_event.data: bytearray # type checking
if is_request: if is_request:
yield RequestData(bytes(h11_event.data), self.stream_id) yield RequestData(self.stream_id, bytes(h11_event.data))
else: else:
yield ResponseData(bytes(h11_event.data), self.stream_id) yield ResponseData(self.stream_id, bytes(h11_event.data))
elif isinstance(h11_event, h11.EndOfMessage): elif isinstance(h11_event, h11.EndOfMessage):
if is_request: if is_request:
yield RequestEndOfMessage(self.stream_id) yield RequestEndOfMessage(self.stream_id)
@ -96,13 +103,15 @@ class Http1Server(Http1Connection):
def __init__(self, conn: Client): def __init__(self, conn: Client):
super().__init__(conn) super().__init__(conn)
self.stream_id = 1
self.state = self.read_request_headers self.state = self.read_request_headers
self.stream_id = -1
def send(self, event: HttpEvent) -> commands.TCommandGenerator: def send(self, event: HttpEvent) -> commands.TCommandGenerator:
assert event.stream_id == self.stream_id
if isinstance(event, ResponseHeaders): if isinstance(event, ResponseHeaders):
self.response = event.response self.response = event.response
raw = http1.assemble_response_head(event.response) raw = http1.assemble_response_head(event.response)
yield commands.SendData(self.conn, raw)
if self.request.first_line_format == "authority": if self.request.first_line_format == "authority":
assert self.state == self.wait assert self.state == self.wait
self.body_reader = self.make_body_reader(-1) self.body_reader = self.make_body_reader(-1)
@ -113,24 +122,35 @@ class Http1Server(Http1Connection):
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data) raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
else: else:
raw = event.data raw = event.data
yield commands.SendData(self.conn, raw)
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower(): if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
raw = b"0\r\n\r\n" yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request, self.response) == -1: elif http1.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
return yield from self.mark_done(response=True)
else: elif isinstance(event, ResponseProtocolError):
raw = False if not self.response:
self.request = None resp = http.make_error_response(event.code, event.message)
self.response = None raw = http1.assemble_response(resp)
self.stream_id += 2 yield commands.SendData(self.conn, raw)
self.state = self.read_request_headers yield commands.CloseConnection(self.conn)
yield from self.state(events.DataReceived(self.conn, b""))
else: else:
raise NotImplementedError(f"{event}") raise NotImplementedError(f"{event}")
if raw: def mark_done(self, *, request: bool = False, response: bool = False):
yield commands.SendData(self.conn, raw) if request:
self.request_done = True
if response:
self.response_done = True
if self.request_done and self.response_done:
self.request_done = self.response_done = False
self.request = self.response = None
self.stream_id += 2
self.state = self.read_request_headers
yield from self.state(events.DataReceived(self.conn, b""))
elif self.request_done:
self.state = self.wait
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]: def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
if isinstance(event, events.DataReceived): if isinstance(event, events.DataReceived):
@ -138,7 +158,7 @@ class Http1Server(Http1Connection):
if request_head: if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head)) self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
yield RequestHeaders(self.request, self.stream_id) yield RequestHeaders(self.stream_id, self.request)
if self.request.first_line_format == "authority": if self.request.first_line_format == "authority":
# The previous proxy server implementation tried to read the request body here: # The previous proxy server implementation tried to read the request body here:
@ -151,17 +171,20 @@ class Http1Server(Http1Connection):
self.body_reader = self.make_body_reader(expected_size) self.body_reader = self.make_body_reader(expected_size)
self.state = self.read_request_body self.state = self.read_request_body
yield from self.state(event) yield from self.state(event)
else:
pass # FIXME: protect against header size DoS
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.ConnectionClosed):
pass # TODO: Better handling, tear everything down. if bytes(self.buf).strip():
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
yield commands.CloseConnection(self.conn)
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]: def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
for e in self.read_body(event, True): for e in self.read_body(event, True):
if isinstance(e, RequestEndOfMessage):
self.state = self.wait
yield from self.state(event)
yield e yield e
if isinstance(e, RequestEndOfMessage):
yield from self.mark_done(request=True)
class Http1Client(Http1Connection): class Http1Client(Http1Connection):
@ -188,25 +211,40 @@ class Http1Client(Http1Connection):
if isinstance(event, RequestHeaders): if isinstance(event, RequestHeaders):
raw = http1.assemble_request_head(event.request) raw = http1.assemble_request_head(event.request)
yield commands.SendData(self.conn, raw)
elif isinstance(event, RequestData): elif isinstance(event, RequestData):
if "chunked" in self.request.headers.get("transfer-encoding", "").lower(): if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data) raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
else: else:
raw = event.data raw = event.data
yield commands.SendData(self.conn, raw)
elif isinstance(event, RequestEndOfMessage): elif isinstance(event, RequestEndOfMessage):
if "chunked" in self.request.headers.get("transfer-encoding", "").lower(): if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
raw = b"0\r\n\r\n" yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request) == -1: elif http1.expected_http_body_size(self.request) == -1:
assert not self.send_queue assert not self.send_queue
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
yield from self.mark_done(request=True)
elif isinstance(event, RequestProtocolError):
yield commands.CloseConnection(self.conn)
return return
else:
raw = False
else: else:
raise NotImplementedError(f"{event}") raise NotImplementedError(f"{event}")
if raw: def mark_done(self, *, request: bool = False, response: bool = False):
yield commands.SendData(self.conn, raw) if request:
self.request_done = True
if response:
self.response_done = True
if self.request_done and self.response_done:
self.request_done = self.response_done = False
self.request = self.response = None
self.stream_id = None
if self.send_queue:
send_queue = self.send_queue
self.send_queue = []
for ev in send_queue:
yield from self.send(ev)
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]: def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
assert isinstance(event, events.ConnectionEvent) assert isinstance(event, events.ConnectionEvent)
@ -216,18 +254,27 @@ class Http1Client(Http1Connection):
if response_head: if response_head:
response_head = [bytes(x) for x in response_head] response_head = [bytes(x) for x in response_head]
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head)) self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
yield ResponseHeaders(self.response, self.stream_id) yield ResponseHeaders(self.stream_id, self.response)
expected_size = http1.expected_http_body_size(self.request, self.response) expected_size = http1.expected_http_body_size(self.request, self.response)
self.body_reader = self.make_body_reader(expected_size) self.body_reader = self.make_body_reader(expected_size)
self.state = self.read_response_body self.state = self.read_response_body
yield from self.state(event) yield from self.state(event)
else:
pass # FIXME: protect against header size DoS
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.ConnectionClosed):
if self.stream_id: if self.stream_id:
raise NotImplementedError(f"{event}") if self.buf:
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}")
else:
# The server has closed the connection to prevent us from continuing.
# We need to signal that to the stream.
# https://tools.ietf.org/html/rfc7231#section-6.5.11
yield ResponseProtocolError(self.stream_id, "server closed connection")
else: else:
return return
yield commands.CloseConnection(self.conn)
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
@ -237,14 +284,7 @@ class Http1Client(Http1Connection):
yield e yield e
if isinstance(e, ResponseEndOfMessage): if isinstance(e, ResponseEndOfMessage):
self.state = self.read_response_headers self.state = self.read_response_headers
self.stream_id = None yield from self.mark_done(response=True)
self.request = None
self.response = None
if self.send_queue:
send_queue = self.send_queue
self.send_queue = []
for ev in send_queue:
yield from self.send(ev)
__all__ = [ __all__ = [

View File

@ -1,9 +1,9 @@
import pytest import pytest
from mitmproxy.http import HTTPResponse, HTTPFlow from mitmproxy.http import HTTPFlow, HTTPResponse
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import layer from mitmproxy.proxy2 import layer
from mitmproxy.proxy2.commands import OpenConnection, SendData from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import http, tls from mitmproxy.proxy2.layers import http, tls
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
@ -168,6 +168,22 @@ def test_http_reply_from_proxy(tctx):
) )
def test_response_until_eof(tctx):
    """
    A response without any body framing is terminated by the server closing the connection.

    The playbook expects no output while the response is still open: the whole
    response is sent to the client once the server connection is closed, and
    the client connection is closed right after.
    """
    server = Placeholder()
    playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)

    # Wire payloads, named so that the event sequence below reads as a script.
    proxy_request = b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n"
    origin_request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
    eof_terminated_response = b"HTTP/1.1 200 OK\r\n\r\nfoo"

    assert (
        playbook
        >> DataReceived(tctx.client, proxy_request)
        << OpenConnection(server)
        >> reply(None)
        << SendData(server, origin_request)
        >> DataReceived(server, eof_terminated_response)
        >> ConnectionClosed(server)
        << SendData(tctx.client, eof_terminated_response)
        << CloseConnection(tctx.client)
    )
def test_disconnect_while_intercept(tctx): def test_disconnect_while_intercept(tctx):
"""Test a server disconnect while a request is intercepted.""" """Test a server disconnect while a request is intercepted."""
tctx.options.connection_strategy = "eager" tctx.options.connection_strategy = "eager"
@ -198,3 +214,130 @@ def test_disconnect_while_intercept(tctx):
) )
assert server1() != server2() assert server1() != server2()
assert flow().server_conn == server2() assert flow().server_conn == server2()
def test_response_streaming(tctx):
    """
    Response bodies are forwarded chunk by chunk once streaming is enabled.

    The response-headers hook installs an uppercasing stream function, so every
    body chunk that reaches the client is expected in upper case, and the first
    chunk is sent together with the response head.
    """
    server = Placeholder()
    flow = Placeholder()

    def stream_uppercased(f: HTTPFlow):
        # Side effect for the hook reply: switch the response to streaming mode.
        def upper(chunk):
            return chunk.upper()

        f.response.stream = upper

    response_head = b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\n"

    assert (
        Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
        >> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
        << OpenConnection(server)
        >> reply(None)
        << SendData(server, b"GET /largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
        # Only the first half of the six-byte body arrives with the head.
        >> DataReceived(server, response_head + b"abc")
        << http.HttpResponseHeadersHook(flow)
        >> reply(side_effect=stream_uppercased)
        << SendData(tctx.client, response_head + b"ABC")
        >> DataReceived(server, b"def")
        << SendData(tctx.client, b"DEF")
    )
@pytest.mark.parametrize("response", ["normal response", "early response", "early close", "early kill"])
def test_request_streaming(tctx, response):
    """
    Request bodies are forwarded chunk by chunk once streaming is enabled.

    Streaming requests is trickier than streaming responses: the server may
    answer, or go away, while the request body is still being sent. Each
    parametrized scenario covers one such server behaviour after the first
    half of the request body has been forwarded.
    """
    server = Placeholder()
    flow = Placeholder()
    playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)

    def stream_uppercased(f: HTTPFlow):
        # Side effect for the hook reply: switch the request to streaming mode.
        def upper(chunk):
            return chunk.upper()

        f.request.stream = upper

    # Response used by both scenarios in which the server answers early.
    too_large = b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n"

    # Common prefix: request head plus the first three of six body bytes.
    assert (
        playbook
        >> DataReceived(tctx.client, b"POST http://example.com/ HTTP/1.1\r\n"
                                     b"Host: example.com\r\n"
                                     b"Content-Length: 6\r\n\r\n"
                                     b"abc")
        << http.HttpRequestHeadersHook(flow)
        >> reply(side_effect=stream_uppercased)
        << OpenConnection(server)
        >> reply(None)
        << SendData(server, b"POST / HTTP/1.1\r\n"
                            b"Host: example.com\r\n"
                            b"Content-Length: 6\r\n\r\n"
                            b"ABC")
    )

    if response == "normal response":
        ok = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
        assert (
            playbook
            >> DataReceived(tctx.client, b"def")
            << SendData(server, b"DEF")
            >> DataReceived(server, ok)
            << SendData(tctx.client, ok)
        )
        return

    if response == "early response":
        # The server may respond before the request has been sent completely.
        # Sending continues for as long as the server keeps the connection open.
        # https://tools.ietf.org/html/rfc7231#section-6.5.11
        assert (
            playbook
            >> DataReceived(server, too_large)
            << SendData(tctx.client, too_large)
            >> DataReceived(tctx.client, b"def")
            << SendData(server, b"DEF")
        )
        return

    if response == "early close":
        # The server responds and then hangs up: both connections get closed.
        assert (
            playbook
            >> DataReceived(server, too_large)
            << SendData(tctx.client, too_large)
            >> ConnectionClosed(server)
            << CloseConnection(server)
            << CloseConnection(tctx.client)
        )
        return

    if response == "early kill":
        # The server hangs up without responding: the client gets an error page.
        err = Placeholder()
        assert (
            playbook
            >> ConnectionClosed(server)
            << CloseConnection(server)
            << SendData(tctx.client, err)
            << CloseConnection(tctx.client)
        )
        assert b"502 Bad Gateway" in err()
        return

    assert False  # pragma: no cover
@pytest.mark.parametrize(
    "data",
    [
        None,
        b"I don't speak HTTP.",
        b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nweee",
    ],
)
def test_server_aborts(tctx, data):
    """
    The server closes the connection without delivering a complete response.

    Covered cases: nothing sent at all, bytes that are not an HTTP response,
    and a response whose body is shorter than its announced Content-Length.
    Each one is expected to fire the error hook and to answer the client with
    a 502 before the client connection is closed.
    """
    server = Placeholder()
    flow = Placeholder()
    error_page = Placeholder()
    playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)

    proxy_request = b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n"
    origin_request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"

    assert (
        playbook
        >> DataReceived(tctx.client, proxy_request)
        << OpenConnection(server)
        >> reply(None)
        << SendData(server, origin_request)
    )

    # Optionally feed the partial or garbage server output before the close.
    if data:
        playbook >> DataReceived(server, data)

    assert (
        playbook
        >> ConnectionClosed(server)
        << CloseConnection(server)
        << http.HttpErrorHook(flow)
        >> reply()
        << SendData(tctx.client, error_page)
        << CloseConnection(tctx.client)
    )

    assert flow().error
    assert b"502 Bad Gateway" in error_page()