mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] implement http streaming, refine error handling
This commit is contained in:
parent
5671012163
commit
0740c673bd
@ -112,6 +112,9 @@ class Hook(Command):
|
||||
all_hooks: typing.Dict[str, typing.Type[Hook]] = {}
|
||||
|
||||
|
||||
# TODO: Move descriptions from addons/events.py into hooks and have hook documentation generated from all_hooks.
|
||||
|
||||
|
||||
class GetSocket(ConnectionCommand):
|
||||
"""
|
||||
Get the underlying socket.
|
||||
|
@ -5,6 +5,7 @@ The counterpart to events are commands.
|
||||
"""
|
||||
import socket
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy.proxy2 import commands
|
||||
from mitmproxy.proxy2.context import Connection
|
||||
@ -27,15 +28,13 @@ class Start(Event):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionEvent(Event):
|
||||
"""
|
||||
All events involving connection IO.
|
||||
"""
|
||||
connection: Connection
|
||||
|
||||
def __init__(self, connection: Connection):
|
||||
self.connection = connection
|
||||
|
||||
|
||||
class ConnectionClosed(ConnectionEvent):
|
||||
"""
|
||||
@ -44,20 +43,19 @@ class ConnectionClosed(ConnectionEvent):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataReceived(ConnectionEvent):
|
||||
"""
|
||||
Remote has sent some data.
|
||||
"""
|
||||
|
||||
def __init__(self, connection: Connection, data: bytes) -> None:
|
||||
super().__init__(connection)
|
||||
self.data = data
|
||||
data: bytes
|
||||
|
||||
def __repr__(self):
|
||||
target = type(self.connection).__name__.lower()
|
||||
return f"DataReceived({target}, {self.data})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandReply(Event):
|
||||
"""
|
||||
Emitted when a command has been finished, e.g.
|
||||
@ -66,10 +64,6 @@ class CommandReply(Event):
|
||||
command: commands.Command
|
||||
reply: typing.Any
|
||||
|
||||
def __init__(self, command: commands.Command, reply: typing.Any):
|
||||
self.command = command
|
||||
self.reply = reply
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls is CommandReply:
|
||||
raise TypeError("CommandReply may not be instantiated directly.")
|
||||
@ -88,35 +82,23 @@ class CommandReply(Event):
|
||||
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenConnectionReply(CommandReply):
|
||||
command: commands.OpenConnection
|
||||
reply: typing.Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: commands.OpenConnection,
|
||||
err: typing.Optional[str]
|
||||
):
|
||||
super().__init__(command, err)
|
||||
"""error message"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class HookReply(CommandReply):
|
||||
command: commands.Hook
|
||||
|
||||
def __init__(self, command: commands.Hook):
|
||||
super().__init__(command, None)
|
||||
reply: None = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"HookReply({repr(self.command)[5:-1]})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetSocketReply(CommandReply):
|
||||
command: commands.GetSocket
|
||||
reply: socket.socket
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: commands.GetSocket,
|
||||
socket: socket.socket
|
||||
):
|
||||
super().__init__(command, socket)
|
||||
|
@ -10,8 +10,10 @@ from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSRe
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
from ._base import HttpConnection, StreamId
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \
|
||||
ResponseHeaders
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
||||
HttpResponseHook
|
||||
from ._http1 import Http1Client, Http1Server
|
||||
from ._http2 import Http2Client
|
||||
|
||||
@ -42,8 +44,8 @@ class GetHttpConnection(HttpCommand):
|
||||
|
||||
class GetHttpConnectionReply(events.CommandReply):
|
||||
command: GetHttpConnection
|
||||
reply: typing.Optional[str]
|
||||
"""error message"""
|
||||
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
|
||||
"""connection object, error message"""
|
||||
|
||||
|
||||
class SendHttp(HttpCommand):
|
||||
@ -58,35 +60,6 @@ class SendHttp(HttpCommand):
|
||||
return f"Send({self.event})"
|
||||
|
||||
|
||||
class HttpRequestHeadersHook(commands.Hook):
|
||||
name = "requestheaders"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpRequestHook(commands.Hook):
|
||||
name = "request"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpResponseHook(commands.Hook):
|
||||
name = "response"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpResponseHeadersHook(commands.Hook):
|
||||
name = "responseheaders"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpConnectHook(commands.Hook):
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpErrorHook(commands.Hook):
|
||||
name = "error"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpStream(Layer):
|
||||
request_body_buf: bytes
|
||||
response_body_buf: bytes
|
||||
@ -103,15 +76,22 @@ class HttpStream(Layer):
|
||||
super().__init__(context)
|
||||
self.request_body_buf = b""
|
||||
self.response_body_buf = b""
|
||||
self._handle_event = self.start
|
||||
self.client_state = self.state_uninitialized
|
||||
self.server_state = self.state_uninitialized
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
self._handle_event = self.read_request_headers
|
||||
yield from ()
|
||||
@expect(events.Start, HttpEvent)
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, events.Start):
|
||||
self.client_state = self.state_wait_for_request_headers
|
||||
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
|
||||
yield from self.handle_protocol_error(event)
|
||||
elif isinstance(event, (RequestHeaders, RequestData, RequestEndOfMessage)):
|
||||
yield from self.client_state(event)
|
||||
else:
|
||||
yield from self.server_state(event)
|
||||
|
||||
@expect(RequestHeaders)
|
||||
def read_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
|
||||
def state_wait_for_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
|
||||
self.stream_id = event.stream_id
|
||||
self.flow = http.HTTPFlow(
|
||||
self.context.client,
|
||||
@ -122,72 +102,12 @@ class HttpStream(Layer):
|
||||
if self.flow.request.first_line_format == "authority":
|
||||
yield from self.handle_connect()
|
||||
return
|
||||
else:
|
||||
yield HttpRequestHeadersHook(self.flow)
|
||||
|
||||
if self.flow.request.headers.get("expect", "").lower() == "100-continue":
|
||||
raise NotImplementedError("expect nothing")
|
||||
# self.send_response(http.expect_continue_response)
|
||||
# request.headers.pop("expect")
|
||||
|
||||
if self.flow.request.stream:
|
||||
raise NotImplementedError # FIXME
|
||||
else:
|
||||
self._handle_event = self.read_request_body
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage)
|
||||
def read_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, RequestData):
|
||||
self.request_body_buf += event.data
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
self.flow.request.data.content = self.request_body_buf
|
||||
self.request_body_buf = b""
|
||||
yield from self.handle_request()
|
||||
|
||||
def handle_connect(self) -> commands.TCommandGenerator:
|
||||
yield HttpConnectHook(self.flow)
|
||||
|
||||
self.context.server = Server((self.flow.request.host, self.flow.request.port))
|
||||
if self.context.options.connection_strategy == "eager":
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
self.flow.response = http.HTTPResponse.make(
|
||||
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
|
||||
)
|
||||
|
||||
if not self.flow.response:
|
||||
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
|
||||
|
||||
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client)
|
||||
|
||||
if 200 <= self.flow.response.status_code < 300:
|
||||
self.child_layer = NextLayer(self.context)
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
self._handle_event = self.passthrough
|
||||
else:
|
||||
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client)
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage, events.Event)
|
||||
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
# HTTP events -> normal connection events
|
||||
if isinstance(event, RequestData):
|
||||
event = events.DataReceived(self.context.client, event.data)
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
event = events.ConnectionClosed(self.context.client)
|
||||
|
||||
for command in self.child_layer.handle_event(event):
|
||||
# normal connection events -> HTTP events
|
||||
if isinstance(command, commands.SendData) and command.connection == self.context.client:
|
||||
yield SendHttp(ResponseData(command.data, self.stream_id), self.context.client)
|
||||
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
|
||||
yield from self.passthrough(events.OpenConnectionReply(command, None))
|
||||
else:
|
||||
yield command
|
||||
|
||||
def handle_request(self) -> commands.TCommandGenerator:
|
||||
# set first line format to relative in regular mode,
|
||||
# see https://github.com/mitmproxy/mitmproxy/issues/1759
|
||||
if self.mode is HTTPMode.regular and self.flow.request.first_line_format == "absolute":
|
||||
@ -210,68 +130,174 @@ class HttpStream(Layer):
|
||||
self.flow.request.port = self.context.server.address[1]
|
||||
self.flow.request.host_header = host_header # set again as .host overwrites this.
|
||||
self.flow.request.scheme = "https" if self.context.server.tls else "http"
|
||||
yield HttpRequestHook(self.flow)
|
||||
|
||||
if self.flow.response:
|
||||
# response was set by an inline script.
|
||||
# we now need to emulate the responseheaders hook.
|
||||
yield HttpResponseHeadersHook(self.flow)
|
||||
yield from self.handle_response()
|
||||
else:
|
||||
connection, err = yield GetHttpConnection(
|
||||
(self.flow.request.host, self.flow.request.port),
|
||||
self.flow.request.scheme == "https"
|
||||
)
|
||||
if err:
|
||||
yield from self.send_error_response(502, err)
|
||||
self.flow.error = flow.Error(err)
|
||||
yield HttpErrorHook(self.flow)
|
||||
yield HttpRequestHeadersHook(self.flow)
|
||||
|
||||
if self.flow.request.stream:
|
||||
if self.flow.response:
|
||||
raise NotImplementedError("Can't set a response and enable streaming at the same time.")
|
||||
ok = yield from self.make_server_connection()
|
||||
if not ok:
|
||||
return
|
||||
else:
|
||||
self.flow.server_conn = connection
|
||||
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
|
||||
self.client_state = self.state_stream_request_body
|
||||
else:
|
||||
self.client_state = self.state_consume_request_body
|
||||
self.server_state = self.state_wait_for_response_headers
|
||||
|
||||
yield SendHttp(RequestHeaders(self.flow.request, self.stream_id), connection)
|
||||
|
||||
if self.flow.request.stream:
|
||||
raise NotImplementedError
|
||||
@expect(RequestData, RequestEndOfMessage)
|
||||
def state_stream_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, RequestData):
|
||||
if callable(self.flow.request.stream):
|
||||
data = self.flow.request.stream(event.data)
|
||||
else:
|
||||
yield SendHttp(RequestData(self.flow.request.data.content, self.stream_id), connection)
|
||||
yield SendHttp(RequestEndOfMessage(self.stream_id), connection)
|
||||
self._handle_event = self.read_response_headers
|
||||
data = event.data
|
||||
yield SendHttp(RequestData(self.stream_id, data), self.context.server)
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
||||
self.client_state = self.state_done
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage)
|
||||
def state_consume_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, RequestData):
|
||||
self.request_body_buf += event.data
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
self.flow.request.data.content = self.request_body_buf
|
||||
self.request_body_buf = b""
|
||||
yield HttpRequestHook(self.flow)
|
||||
if self.flow.response:
|
||||
# response was set by an inline script.
|
||||
# we now need to emulate the responseheaders hook.
|
||||
yield HttpResponseHeadersHook(self.flow)
|
||||
yield from self.send_response()
|
||||
else:
|
||||
ok = yield from self.make_server_connection()
|
||||
if not ok:
|
||||
return
|
||||
|
||||
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
|
||||
yield SendHttp(RequestData(self.stream_id, self.flow.request.data.content), self.context.server)
|
||||
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
||||
|
||||
self.client_state = self.state_done
|
||||
|
||||
@expect(ResponseHeaders)
|
||||
def read_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
|
||||
def state_wait_for_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
|
||||
self.flow.response = event.response
|
||||
yield HttpResponseHeadersHook(self.flow)
|
||||
if not self.flow.response.stream:
|
||||
self._handle_event = self.read_response_body
|
||||
if self.flow.response.stream:
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||
self.server_state = self.state_stream_response_body
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.server_state = self.state_consume_response_body
|
||||
|
||||
@expect(ResponseData, ResponseEndOfMessage)
|
||||
def read_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
def state_stream_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, ResponseData):
|
||||
if callable(self.flow.response.stream):
|
||||
data = self.flow.response.stream(event.data)
|
||||
else:
|
||||
data = event.data
|
||||
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
self.server_state = self.state_done
|
||||
|
||||
@expect(ResponseData, ResponseEndOfMessage)
|
||||
def state_consume_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, ResponseData):
|
||||
self.response_body_buf += event.data
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
self.flow.response.data.content = self.response_body_buf
|
||||
self.response_body_buf = b""
|
||||
yield from self.handle_response()
|
||||
yield from self.send_response()
|
||||
|
||||
def handle_response(self):
|
||||
def send_response(self):
|
||||
yield HttpResponseHook(self.flow)
|
||||
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client)
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
self.server_state = self.state_done
|
||||
|
||||
if self.flow.response.stream:
|
||||
raise NotImplementedError
|
||||
def handle_protocol_error(
|
||||
self,
|
||||
event: typing.Union[RequestProtocolError, ResponseProtocolError]
|
||||
) -> commands.TCommandGenerator:
|
||||
self.flow.error = flow.Error(event.message)
|
||||
yield HttpErrorHook(self.flow)
|
||||
|
||||
if isinstance(event, RequestProtocolError):
|
||||
yield SendHttp(event, self.context.server)
|
||||
else:
|
||||
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client)
|
||||
yield SendHttp(event, self.context.client)
|
||||
return
|
||||
|
||||
def make_server_connection(self) -> typing.Generator[commands.Command, typing.Any, bool]:
|
||||
connection, err = yield GetHttpConnection(
|
||||
(self.flow.request.host, self.flow.request.port),
|
||||
self.flow.request.scheme == "https"
|
||||
)
|
||||
if err:
|
||||
yield from self.handle_protocol_error(ResponseProtocolError(self.stream_id, err))
|
||||
return False
|
||||
else:
|
||||
self.context.server = self.flow.server_conn = connection
|
||||
return True
|
||||
|
||||
def handle_connect(self) -> commands.TCommandGenerator:
|
||||
yield HttpConnectHook(self.flow)
|
||||
|
||||
self.context.server = Server((self.flow.request.host, self.flow.request.port))
|
||||
if self.context.options.connection_strategy == "eager":
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
self.flow.response = http.HTTPResponse.make(
|
||||
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
|
||||
)
|
||||
|
||||
if not self.flow.response:
|
||||
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
|
||||
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||
|
||||
if 200 <= self.flow.response.status_code < 300:
|
||||
self.child_layer = NextLayer(self.context)
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
self._handle_event = self.passthrough
|
||||
else:
|
||||
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
|
||||
def send_error_response(self, status_code: int, message: str, headers=None):
|
||||
response = http.make_error_response(status_code, message, headers)
|
||||
yield SendHttp(ResponseHeaders(response, self.stream_id), self.context.client)
|
||||
yield SendHttp(ResponseData(response.data.content, self.stream_id), self.context.client)
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
@expect(RequestData, RequestEndOfMessage, events.Event)
|
||||
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
# HTTP events -> normal connection events
|
||||
if isinstance(event, RequestData):
|
||||
event = events.DataReceived(self.context.client, event.data)
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
event = events.ConnectionClosed(self.context.client)
|
||||
|
||||
for command in self.child_layer.handle_event(event):
|
||||
# normal connection events -> HTTP events
|
||||
if isinstance(command, commands.SendData) and command.connection == self.context.client:
|
||||
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
|
||||
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
|
||||
yield from self.passthrough(events.OpenConnectionReply(command, None))
|
||||
else:
|
||||
yield command
|
||||
|
||||
@expect()
|
||||
def state_uninitialized(self, _) -> commands.TCommandGenerator:
|
||||
yield from ()
|
||||
|
||||
@expect()
|
||||
def state_done(self, _) -> commands.TCommandGenerator:
|
||||
yield from ()
|
||||
|
||||
def state_errored(self, _) -> commands.TCommandGenerator:
|
||||
# silently consume every event.
|
||||
yield from ()
|
||||
|
||||
|
||||
class HTTPLayer(Layer):
|
||||
@ -304,8 +330,6 @@ class HTTPLayer(Layer):
|
||||
self.connections = {
|
||||
context.client: Http1Server(context.client)
|
||||
}
|
||||
if self.context.server.connected:
|
||||
self.make_http_connection(self.context.server)
|
||||
|
||||
def __repr__(self):
|
||||
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
|
||||
@ -322,8 +346,6 @@ class HTTPLayer(Layer):
|
||||
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
|
||||
else:
|
||||
yield from self.make_http_connection(event.command.connection)
|
||||
elif isinstance(event, EstablishServerTLSReply) and event.command.connection in self.waiting_for_connection:
|
||||
yield from self.make_http_connection(event.command.connection)
|
||||
elif isinstance(event, events.CommandReply):
|
||||
try:
|
||||
stream = self.stream_by_command.pop(event.command)
|
||||
|
@ -1,18 +1,16 @@
|
||||
import abc
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy.proxy2 import commands, events
|
||||
|
||||
StreamId = int
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpEvent(events.Event):
|
||||
stream_id: StreamId
|
||||
|
||||
# we need stream ids on every event to avoid race conditions
|
||||
|
||||
def __init__(self, stream_id: StreamId):
|
||||
self.stream_id = stream_id
|
||||
stream_id: StreamId
|
||||
|
||||
def __repr__(self) -> str:
|
||||
x = self.__dict__.copy()
|
||||
|
@ -1,38 +1,28 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import http
|
||||
from ._base import HttpEvent, StreamId
|
||||
from ._base import HttpEvent
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestHeaders(HttpEvent):
|
||||
request: http.HTTPRequest
|
||||
|
||||
def __init__(self, request: http.HTTPRequest, stream_id: StreamId):
|
||||
super().__init__(stream_id)
|
||||
self.request = request
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseHeaders(HttpEvent):
|
||||
response: http.HTTPResponse
|
||||
|
||||
def __init__(self, response: http.HTTPResponse, stream_id: StreamId):
|
||||
super().__init__(stream_id)
|
||||
self.response = response
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestData(HttpEvent):
|
||||
data: bytes
|
||||
|
||||
def __init__(self, data: bytes, stream_id: StreamId):
|
||||
super().__init__(stream_id)
|
||||
self.data = data
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseData(HttpEvent):
|
||||
data: bytes
|
||||
|
||||
def __init__(self, data: bytes, stream_id: StreamId):
|
||||
super().__init__(stream_id)
|
||||
self.data = data
|
||||
|
||||
|
||||
class RequestEndOfMessage(HttpEvent):
|
||||
pass
|
||||
@ -42,6 +32,18 @@ class ResponseEndOfMessage(HttpEvent):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestProtocolError(HttpEvent):
|
||||
message: str
|
||||
code: int = 400
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseProtocolError(HttpEvent):
|
||||
message: str
|
||||
code: int = 502
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HttpEvent",
|
||||
"RequestHeaders",
|
||||
@ -50,4 +52,6 @@ __all__ = [
|
||||
"ResponseHeaders",
|
||||
"ResponseData",
|
||||
"ResponseEndOfMessage",
|
||||
"RequestProtocolError",
|
||||
"ResponseProtocolError",
|
||||
]
|
||||
|
31
mitmproxy/proxy2/layers/http/_hooks.py
Normal file
31
mitmproxy/proxy2/layers/http/_hooks.py
Normal file
@ -0,0 +1,31 @@
|
||||
from mitmproxy import http
|
||||
from mitmproxy.proxy2 import commands
|
||||
|
||||
|
||||
class HttpRequestHeadersHook(commands.Hook):
|
||||
name = "requestheaders"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpRequestHook(commands.Hook):
|
||||
name = "request"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpResponseHook(commands.Hook):
|
||||
name = "response"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpResponseHeadersHook(commands.Hook):
|
||||
name = "responseheaders"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpConnectHook(commands.Hook):
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class HttpErrorHook(commands.Hook):
|
||||
name = "error"
|
||||
flow: http.HTTPFlow
|
@ -12,17 +12,19 @@ from mitmproxy.proxy2 import commands, events
|
||||
from mitmproxy.proxy2.context import Client, Connection, Server
|
||||
from mitmproxy.proxy2.layers.http._base import StreamId
|
||||
from ._base import HttpConnection
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \
|
||||
ResponseHeaders
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
|
||||
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
||||
|
||||
|
||||
class Http1Connection(HttpConnection):
|
||||
conn: Connection
|
||||
stream_id: StreamId = None
|
||||
request: http.HTTPRequest
|
||||
response: http.HTTPResponse
|
||||
stream_id: typing.Optional[StreamId] = None
|
||||
request: typing.Optional[http.HTTPRequest] = None
|
||||
response: typing.Optional[http.HTTPResponse] = None
|
||||
request_done: bool = False
|
||||
response_done: bool = False
|
||||
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
|
||||
body_reader: TBodyReader
|
||||
buf: ReceiveBuffer
|
||||
@ -58,17 +60,22 @@ class Http1Connection(HttpConnection):
|
||||
h11_event = self.body_reader.read_eof()
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
except h11.ProtocolError:
|
||||
raise # FIXME
|
||||
except h11.ProtocolError as e:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
if is_request:
|
||||
yield RequestProtocolError(self.stream_id, str(e))
|
||||
else:
|
||||
yield ResponseProtocolError(self.stream_id, str(e))
|
||||
return
|
||||
|
||||
if h11_event is None:
|
||||
return
|
||||
elif isinstance(h11_event, h11.Data):
|
||||
h11_event.data: bytearray # type checking
|
||||
if is_request:
|
||||
yield RequestData(bytes(h11_event.data), self.stream_id)
|
||||
yield RequestData(self.stream_id, bytes(h11_event.data))
|
||||
else:
|
||||
yield ResponseData(bytes(h11_event.data), self.stream_id)
|
||||
yield ResponseData(self.stream_id, bytes(h11_event.data))
|
||||
elif isinstance(h11_event, h11.EndOfMessage):
|
||||
if is_request:
|
||||
yield RequestEndOfMessage(self.stream_id)
|
||||
@ -96,13 +103,15 @@ class Http1Server(Http1Connection):
|
||||
|
||||
def __init__(self, conn: Client):
|
||||
super().__init__(conn)
|
||||
self.stream_id = 1
|
||||
self.state = self.read_request_headers
|
||||
self.stream_id = -1
|
||||
|
||||
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
|
||||
assert event.stream_id == self.stream_id
|
||||
if isinstance(event, ResponseHeaders):
|
||||
self.response = event.response
|
||||
raw = http1.assemble_response_head(event.response)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
if self.request.first_line_format == "authority":
|
||||
assert self.state == self.wait
|
||||
self.body_reader = self.make_body_reader(-1)
|
||||
@ -113,24 +122,35 @@ class Http1Server(Http1Connection):
|
||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||
else:
|
||||
raw = event.data
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"0\r\n\r\n"
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
else:
|
||||
raw = False
|
||||
self.request = None
|
||||
self.response = None
|
||||
self.stream_id += 2
|
||||
self.state = self.read_request_headers
|
||||
yield from self.state(events.DataReceived(self.conn, b""))
|
||||
yield from self.mark_done(response=True)
|
||||
elif isinstance(event, ResponseProtocolError):
|
||||
if not self.response:
|
||||
resp = http.make_error_response(event.code, event.message)
|
||||
raw = http1.assemble_response(resp)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise NotImplementedError(f"{event}")
|
||||
|
||||
if raw:
|
||||
yield commands.SendData(self.conn, raw)
|
||||
def mark_done(self, *, request: bool = False, response: bool = False):
|
||||
if request:
|
||||
self.request_done = True
|
||||
if response:
|
||||
self.response_done = True
|
||||
if self.request_done and self.response_done:
|
||||
self.request_done = self.response_done = False
|
||||
self.request = self.response = None
|
||||
self.stream_id += 2
|
||||
self.state = self.read_request_headers
|
||||
yield from self.state(events.DataReceived(self.conn, b""))
|
||||
elif self.request_done:
|
||||
self.state = self.wait
|
||||
|
||||
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
@ -138,7 +158,7 @@ class Http1Server(Http1Connection):
|
||||
if request_head:
|
||||
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
||||
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
||||
yield RequestHeaders(self.request, self.stream_id)
|
||||
yield RequestHeaders(self.stream_id, self.request)
|
||||
|
||||
if self.request.first_line_format == "authority":
|
||||
# The previous proxy server implementation tried to read the request body here:
|
||||
@ -151,17 +171,20 @@ class Http1Server(Http1Connection):
|
||||
self.body_reader = self.make_body_reader(expected_size)
|
||||
self.state = self.read_request_body
|
||||
yield from self.state(event)
|
||||
else:
|
||||
pass # FIXME: protect against header size DoS
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
pass # TODO: Better handling, tear everything down.
|
||||
if bytes(self.buf).strip():
|
||||
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
||||
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
for e in self.read_body(event, True):
|
||||
if isinstance(e, RequestEndOfMessage):
|
||||
self.state = self.wait
|
||||
yield from self.state(event)
|
||||
yield e
|
||||
if isinstance(e, RequestEndOfMessage):
|
||||
yield from self.mark_done(request=True)
|
||||
|
||||
|
||||
class Http1Client(Http1Connection):
|
||||
@ -188,25 +211,40 @@ class Http1Client(Http1Connection):
|
||||
|
||||
if isinstance(event, RequestHeaders):
|
||||
raw = http1.assemble_request_head(event.request)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, RequestData):
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||
else:
|
||||
raw = event.data
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"0\r\n\r\n"
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
elif http1.expected_http_body_size(self.request) == -1:
|
||||
assert not self.send_queue
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
else:
|
||||
raw = False
|
||||
yield from self.mark_done(request=True)
|
||||
elif isinstance(event, RequestProtocolError):
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
else:
|
||||
raise NotImplementedError(f"{event}")
|
||||
|
||||
if raw:
|
||||
yield commands.SendData(self.conn, raw)
|
||||
def mark_done(self, *, request: bool = False, response: bool = False):
|
||||
if request:
|
||||
self.request_done = True
|
||||
if response:
|
||||
self.response_done = True
|
||||
if self.request_done and self.response_done:
|
||||
self.request_done = self.response_done = False
|
||||
self.request = self.response = None
|
||||
self.stream_id = None
|
||||
if self.send_queue:
|
||||
send_queue = self.send_queue
|
||||
self.send_queue = []
|
||||
for ev in send_queue:
|
||||
yield from self.send(ev)
|
||||
|
||||
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||
assert isinstance(event, events.ConnectionEvent)
|
||||
@ -216,18 +254,27 @@ class Http1Client(Http1Connection):
|
||||
if response_head:
|
||||
response_head = [bytes(x) for x in response_head]
|
||||
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
|
||||
yield ResponseHeaders(self.response, self.stream_id)
|
||||
yield ResponseHeaders(self.stream_id, self.response)
|
||||
|
||||
expected_size = http1.expected_http_body_size(self.request, self.response)
|
||||
self.body_reader = self.make_body_reader(expected_size)
|
||||
|
||||
self.state = self.read_response_body
|
||||
yield from self.state(event)
|
||||
else:
|
||||
pass # FIXME: protect against header size DoS
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
if self.stream_id:
|
||||
raise NotImplementedError(f"{event}")
|
||||
if self.buf:
|
||||
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}")
|
||||
else:
|
||||
# The server has closed the connection to prevent us from continuing.
|
||||
# We need to signal that to the stream.
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
||||
yield ResponseProtocolError(self.stream_id, "server closed connection")
|
||||
else:
|
||||
return
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
||||
@ -237,14 +284,7 @@ class Http1Client(Http1Connection):
|
||||
yield e
|
||||
if isinstance(e, ResponseEndOfMessage):
|
||||
self.state = self.read_response_headers
|
||||
self.stream_id = None
|
||||
self.request = None
|
||||
self.response = None
|
||||
if self.send_queue:
|
||||
send_queue = self.send_queue
|
||||
self.send_queue = []
|
||||
for ev in send_queue:
|
||||
yield from self.send(ev)
|
||||
yield from self.mark_done(response=True)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -1,9 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from mitmproxy.http import HTTPResponse, HTTPFlow
|
||||
from mitmproxy.http import HTTPFlow, HTTPResponse
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import layer
|
||||
from mitmproxy.proxy2.commands import OpenConnection, SendData
|
||||
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import http, tls
|
||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
|
||||
@ -168,6 +168,22 @@ def test_http_reply_from_proxy(tctx):
|
||||
)
|
||||
|
||||
|
||||
def test_response_until_eof(tctx):
|
||||
"""Test scenario where the server response body is terminated by EOF."""
|
||||
server = Placeholder()
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
<< SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
>> DataReceived(server, b"HTTP/1.1 200 OK\r\n\r\nfoo")
|
||||
>> ConnectionClosed(server)
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\n\r\nfoo")
|
||||
<< CloseConnection(tctx.client)
|
||||
)
|
||||
|
||||
|
||||
def test_disconnect_while_intercept(tctx):
|
||||
"""Test a server disconnect while a request is intercepted."""
|
||||
tctx.options.connection_strategy = "eager"
|
||||
@ -198,3 +214,130 @@ def test_disconnect_while_intercept(tctx):
|
||||
)
|
||||
assert server1() != server2()
|
||||
assert flow().server_conn == server2()
|
||||
|
||||
|
||||
def test_response_streaming(tctx):
|
||||
"""Test HTTP response streaming"""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
|
||||
def enable_streaming(flow: HTTPFlow):
|
||||
flow.response.stream = lambda x: x.upper()
|
||||
|
||||
assert (
|
||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
<< SendData(server, b"GET /largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nabc")
|
||||
<< http.HttpResponseHeadersHook(flow)
|
||||
>> reply(side_effect=enable_streaming)
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nABC")
|
||||
>> DataReceived(server, b"def")
|
||||
<< SendData(tctx.client, b"DEF")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response", ["normal response", "early response", "early close", "early kill"])
|
||||
def test_request_streaming(tctx, response):
|
||||
"""
|
||||
Test HTTP request streaming
|
||||
|
||||
This is a bit more contrived as we may receive server data while we are still sending the request.
|
||||
"""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
|
||||
def enable_streaming(flow: HTTPFlow):
|
||||
flow.request.stream = lambda x: x.upper()
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"POST http://example.com/ HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Content-Length: 6\r\n\r\n"
|
||||
b"abc")
|
||||
<< http.HttpRequestHeadersHook(flow)
|
||||
>> reply(side_effect=enable_streaming)
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
<< SendData(server, b"POST / HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Content-Length: 6\r\n\r\n"
|
||||
b"ABC")
|
||||
)
|
||||
if response == "normal response":
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"def")
|
||||
<< SendData(server, b"DEF")
|
||||
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
)
|
||||
elif response == "early response":
|
||||
# We may receive a response before we have finished sending our request.
|
||||
# We continue sending unless the server closes the connection.
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(server, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||
>> DataReceived(tctx.client, b"def")
|
||||
<< SendData(server, b"DEF")
|
||||
)
|
||||
elif response == "early close":
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(server, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||
>> ConnectionClosed(server)
|
||||
<< CloseConnection(server)
|
||||
<< CloseConnection(tctx.client)
|
||||
)
|
||||
elif response == "early kill":
|
||||
err = Placeholder()
|
||||
assert (
|
||||
playbook
|
||||
>> ConnectionClosed(server)
|
||||
<< CloseConnection(server)
|
||||
<< SendData(tctx.client, err)
|
||||
<< CloseConnection(tctx.client)
|
||||
)
|
||||
assert b"502 Bad Gateway" in err()
|
||||
else: # pragma: no cover
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data", [
|
||||
None,
|
||||
b"I don't speak HTTP.",
|
||||
b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nweee"
|
||||
])
|
||||
def test_server_aborts(tctx, data):
|
||||
"""Test the scenario where the server doesn't serve a response"""
|
||||
server = Placeholder()
|
||||
flow = Placeholder()
|
||||
err = Placeholder()
|
||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
<< SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
)
|
||||
if data:
|
||||
playbook >> DataReceived(server, data)
|
||||
assert (
|
||||
playbook
|
||||
>> ConnectionClosed(server)
|
||||
<< CloseConnection(server)
|
||||
<< http.HttpErrorHook(flow)
|
||||
>> reply()
|
||||
<< SendData(tctx.client, err)
|
||||
<< CloseConnection(tctx.client)
|
||||
)
|
||||
assert flow().error
|
||||
assert b"502 Bad Gateway" in err()
|
||||
|
Loading…
Reference in New Issue
Block a user