mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
[sans-io] implement http streaming, refine error handling
This commit is contained in:
parent
5671012163
commit
0740c673bd
@ -112,6 +112,9 @@ class Hook(Command):
|
|||||||
all_hooks: typing.Dict[str, typing.Type[Hook]] = {}
|
all_hooks: typing.Dict[str, typing.Type[Hook]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Move descriptions from addons/events.py into hooks and have hook documentation generated from all_hooks.
|
||||||
|
|
||||||
|
|
||||||
class GetSocket(ConnectionCommand):
|
class GetSocket(ConnectionCommand):
|
||||||
"""
|
"""
|
||||||
Get the underlying socket.
|
Get the underlying socket.
|
||||||
|
@ -5,6 +5,7 @@ The counterpart to events are commands.
|
|||||||
"""
|
"""
|
||||||
import socket
|
import socket
|
||||||
import typing
|
import typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mitmproxy.proxy2 import commands
|
from mitmproxy.proxy2 import commands
|
||||||
from mitmproxy.proxy2.context import Connection
|
from mitmproxy.proxy2.context import Connection
|
||||||
@ -27,15 +28,13 @@ class Start(Event):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class ConnectionEvent(Event):
|
class ConnectionEvent(Event):
|
||||||
"""
|
"""
|
||||||
All events involving connection IO.
|
All events involving connection IO.
|
||||||
"""
|
"""
|
||||||
connection: Connection
|
connection: Connection
|
||||||
|
|
||||||
def __init__(self, connection: Connection):
|
|
||||||
self.connection = connection
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionClosed(ConnectionEvent):
|
class ConnectionClosed(ConnectionEvent):
|
||||||
"""
|
"""
|
||||||
@ -44,20 +43,19 @@ class ConnectionClosed(ConnectionEvent):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class DataReceived(ConnectionEvent):
|
class DataReceived(ConnectionEvent):
|
||||||
"""
|
"""
|
||||||
Remote has sent some data.
|
Remote has sent some data.
|
||||||
"""
|
"""
|
||||||
|
data: bytes
|
||||||
def __init__(self, connection: Connection, data: bytes) -> None:
|
|
||||||
super().__init__(connection)
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
target = type(self.connection).__name__.lower()
|
target = type(self.connection).__name__.lower()
|
||||||
return f"DataReceived({target}, {self.data})"
|
return f"DataReceived({target}, {self.data})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class CommandReply(Event):
|
class CommandReply(Event):
|
||||||
"""
|
"""
|
||||||
Emitted when a command has been finished, e.g.
|
Emitted when a command has been finished, e.g.
|
||||||
@ -66,10 +64,6 @@ class CommandReply(Event):
|
|||||||
command: commands.Command
|
command: commands.Command
|
||||||
reply: typing.Any
|
reply: typing.Any
|
||||||
|
|
||||||
def __init__(self, command: commands.Command, reply: typing.Any):
|
|
||||||
self.command = command
|
|
||||||
self.reply = reply
|
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
if cls is CommandReply:
|
if cls is CommandReply:
|
||||||
raise TypeError("CommandReply may not be instantiated directly.")
|
raise TypeError("CommandReply may not be instantiated directly.")
|
||||||
@ -88,35 +82,23 @@ class CommandReply(Event):
|
|||||||
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
|
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class OpenConnectionReply(CommandReply):
|
class OpenConnectionReply(CommandReply):
|
||||||
command: commands.OpenConnection
|
command: commands.OpenConnection
|
||||||
reply: typing.Optional[str]
|
reply: typing.Optional[str]
|
||||||
|
"""error message"""
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
command: commands.OpenConnection,
|
|
||||||
err: typing.Optional[str]
|
|
||||||
):
|
|
||||||
super().__init__(command, err)
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class HookReply(CommandReply):
|
class HookReply(CommandReply):
|
||||||
command: commands.Hook
|
command: commands.Hook
|
||||||
|
reply: None = None
|
||||||
def __init__(self, command: commands.Hook):
|
|
||||||
super().__init__(command, None)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"HookReply({repr(self.command)[5:-1]})"
|
return f"HookReply({repr(self.command)[5:-1]})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class GetSocketReply(CommandReply):
|
class GetSocketReply(CommandReply):
|
||||||
command: commands.GetSocket
|
command: commands.GetSocket
|
||||||
reply: socket.socket
|
reply: socket.socket
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
command: commands.GetSocket,
|
|
||||||
socket: socket.socket
|
|
||||||
):
|
|
||||||
super().__init__(command, socket)
|
|
||||||
|
@ -10,8 +10,10 @@ from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSRe
|
|||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
from ._base import HttpConnection, StreamId
|
from ._base import HttpConnection, StreamId
|
||||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \
|
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||||
ResponseHeaders
|
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||||
|
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
||||||
|
HttpResponseHook
|
||||||
from ._http1 import Http1Client, Http1Server
|
from ._http1 import Http1Client, Http1Server
|
||||||
from ._http2 import Http2Client
|
from ._http2 import Http2Client
|
||||||
|
|
||||||
@ -42,8 +44,8 @@ class GetHttpConnection(HttpCommand):
|
|||||||
|
|
||||||
class GetHttpConnectionReply(events.CommandReply):
|
class GetHttpConnectionReply(events.CommandReply):
|
||||||
command: GetHttpConnection
|
command: GetHttpConnection
|
||||||
reply: typing.Optional[str]
|
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
|
||||||
"""error message"""
|
"""connection object, error message"""
|
||||||
|
|
||||||
|
|
||||||
class SendHttp(HttpCommand):
|
class SendHttp(HttpCommand):
|
||||||
@ -58,35 +60,6 @@ class SendHttp(HttpCommand):
|
|||||||
return f"Send({self.event})"
|
return f"Send({self.event})"
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestHeadersHook(commands.Hook):
|
|
||||||
name = "requestheaders"
|
|
||||||
flow: http.HTTPFlow
|
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestHook(commands.Hook):
|
|
||||||
name = "request"
|
|
||||||
flow: http.HTTPFlow
|
|
||||||
|
|
||||||
|
|
||||||
class HttpResponseHook(commands.Hook):
|
|
||||||
name = "response"
|
|
||||||
flow: http.HTTPFlow
|
|
||||||
|
|
||||||
|
|
||||||
class HttpResponseHeadersHook(commands.Hook):
|
|
||||||
name = "responseheaders"
|
|
||||||
flow: http.HTTPFlow
|
|
||||||
|
|
||||||
|
|
||||||
class HttpConnectHook(commands.Hook):
|
|
||||||
flow: http.HTTPFlow
|
|
||||||
|
|
||||||
|
|
||||||
class HttpErrorHook(commands.Hook):
|
|
||||||
name = "error"
|
|
||||||
flow: http.HTTPFlow
|
|
||||||
|
|
||||||
|
|
||||||
class HttpStream(Layer):
|
class HttpStream(Layer):
|
||||||
request_body_buf: bytes
|
request_body_buf: bytes
|
||||||
response_body_buf: bytes
|
response_body_buf: bytes
|
||||||
@ -103,15 +76,22 @@ class HttpStream(Layer):
|
|||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.request_body_buf = b""
|
self.request_body_buf = b""
|
||||||
self.response_body_buf = b""
|
self.response_body_buf = b""
|
||||||
self._handle_event = self.start
|
self.client_state = self.state_uninitialized
|
||||||
|
self.server_state = self.state_uninitialized
|
||||||
|
|
||||||
@expect(events.Start)
|
@expect(events.Start, HttpEvent)
|
||||||
def start(self, event: events.Event) -> commands.TCommandGenerator:
|
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
self._handle_event = self.read_request_headers
|
if isinstance(event, events.Start):
|
||||||
yield from ()
|
self.client_state = self.state_wait_for_request_headers
|
||||||
|
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
|
||||||
|
yield from self.handle_protocol_error(event)
|
||||||
|
elif isinstance(event, (RequestHeaders, RequestData, RequestEndOfMessage)):
|
||||||
|
yield from self.client_state(event)
|
||||||
|
else:
|
||||||
|
yield from self.server_state(event)
|
||||||
|
|
||||||
@expect(RequestHeaders)
|
@expect(RequestHeaders)
|
||||||
def read_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
|
def state_wait_for_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
|
||||||
self.stream_id = event.stream_id
|
self.stream_id = event.stream_id
|
||||||
self.flow = http.HTTPFlow(
|
self.flow = http.HTTPFlow(
|
||||||
self.context.client,
|
self.context.client,
|
||||||
@ -122,72 +102,12 @@ class HttpStream(Layer):
|
|||||||
if self.flow.request.first_line_format == "authority":
|
if self.flow.request.first_line_format == "authority":
|
||||||
yield from self.handle_connect()
|
yield from self.handle_connect()
|
||||||
return
|
return
|
||||||
else:
|
|
||||||
yield HttpRequestHeadersHook(self.flow)
|
|
||||||
|
|
||||||
if self.flow.request.headers.get("expect", "").lower() == "100-continue":
|
if self.flow.request.headers.get("expect", "").lower() == "100-continue":
|
||||||
raise NotImplementedError("expect nothing")
|
raise NotImplementedError("expect nothing")
|
||||||
# self.send_response(http.expect_continue_response)
|
# self.send_response(http.expect_continue_response)
|
||||||
# request.headers.pop("expect")
|
# request.headers.pop("expect")
|
||||||
|
|
||||||
if self.flow.request.stream:
|
|
||||||
raise NotImplementedError # FIXME
|
|
||||||
else:
|
|
||||||
self._handle_event = self.read_request_body
|
|
||||||
|
|
||||||
@expect(RequestData, RequestEndOfMessage)
|
|
||||||
def read_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
|
||||||
if isinstance(event, RequestData):
|
|
||||||
self.request_body_buf += event.data
|
|
||||||
elif isinstance(event, RequestEndOfMessage):
|
|
||||||
self.flow.request.data.content = self.request_body_buf
|
|
||||||
self.request_body_buf = b""
|
|
||||||
yield from self.handle_request()
|
|
||||||
|
|
||||||
def handle_connect(self) -> commands.TCommandGenerator:
|
|
||||||
yield HttpConnectHook(self.flow)
|
|
||||||
|
|
||||||
self.context.server = Server((self.flow.request.host, self.flow.request.port))
|
|
||||||
if self.context.options.connection_strategy == "eager":
|
|
||||||
err = yield commands.OpenConnection(self.context.server)
|
|
||||||
if err:
|
|
||||||
self.flow.response = http.HTTPResponse.make(
|
|
||||||
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.flow.response:
|
|
||||||
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
|
|
||||||
|
|
||||||
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client)
|
|
||||||
|
|
||||||
if 200 <= self.flow.response.status_code < 300:
|
|
||||||
self.child_layer = NextLayer(self.context)
|
|
||||||
yield from self.child_layer.handle_event(events.Start())
|
|
||||||
self._handle_event = self.passthrough
|
|
||||||
else:
|
|
||||||
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client)
|
|
||||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
|
||||||
|
|
||||||
@expect(RequestData, RequestEndOfMessage, events.Event)
|
|
||||||
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
|
|
||||||
# HTTP events -> normal connection events
|
|
||||||
if isinstance(event, RequestData):
|
|
||||||
event = events.DataReceived(self.context.client, event.data)
|
|
||||||
elif isinstance(event, RequestEndOfMessage):
|
|
||||||
event = events.ConnectionClosed(self.context.client)
|
|
||||||
|
|
||||||
for command in self.child_layer.handle_event(event):
|
|
||||||
# normal connection events -> HTTP events
|
|
||||||
if isinstance(command, commands.SendData) and command.connection == self.context.client:
|
|
||||||
yield SendHttp(ResponseData(command.data, self.stream_id), self.context.client)
|
|
||||||
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
|
|
||||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
|
||||||
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
|
|
||||||
yield from self.passthrough(events.OpenConnectionReply(command, None))
|
|
||||||
else:
|
|
||||||
yield command
|
|
||||||
|
|
||||||
def handle_request(self) -> commands.TCommandGenerator:
|
|
||||||
# set first line format to relative in regular mode,
|
# set first line format to relative in regular mode,
|
||||||
# see https://github.com/mitmproxy/mitmproxy/issues/1759
|
# see https://github.com/mitmproxy/mitmproxy/issues/1759
|
||||||
if self.mode is HTTPMode.regular and self.flow.request.first_line_format == "absolute":
|
if self.mode is HTTPMode.regular and self.flow.request.first_line_format == "absolute":
|
||||||
@ -210,68 +130,174 @@ class HttpStream(Layer):
|
|||||||
self.flow.request.port = self.context.server.address[1]
|
self.flow.request.port = self.context.server.address[1]
|
||||||
self.flow.request.host_header = host_header # set again as .host overwrites this.
|
self.flow.request.host_header = host_header # set again as .host overwrites this.
|
||||||
self.flow.request.scheme = "https" if self.context.server.tls else "http"
|
self.flow.request.scheme = "https" if self.context.server.tls else "http"
|
||||||
yield HttpRequestHook(self.flow)
|
|
||||||
|
|
||||||
|
yield HttpRequestHeadersHook(self.flow)
|
||||||
|
|
||||||
|
if self.flow.request.stream:
|
||||||
|
if self.flow.response:
|
||||||
|
raise NotImplementedError("Can't set a response and enable streaming at the same time.")
|
||||||
|
ok = yield from self.make_server_connection()
|
||||||
|
if not ok:
|
||||||
|
return
|
||||||
|
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
|
||||||
|
self.client_state = self.state_stream_request_body
|
||||||
|
else:
|
||||||
|
self.client_state = self.state_consume_request_body
|
||||||
|
self.server_state = self.state_wait_for_response_headers
|
||||||
|
|
||||||
|
@expect(RequestData, RequestEndOfMessage)
|
||||||
|
def state_stream_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
|
if isinstance(event, RequestData):
|
||||||
|
if callable(self.flow.request.stream):
|
||||||
|
data = self.flow.request.stream(event.data)
|
||||||
|
else:
|
||||||
|
data = event.data
|
||||||
|
yield SendHttp(RequestData(self.stream_id, data), self.context.server)
|
||||||
|
elif isinstance(event, RequestEndOfMessage):
|
||||||
|
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
||||||
|
self.client_state = self.state_done
|
||||||
|
|
||||||
|
@expect(RequestData, RequestEndOfMessage)
|
||||||
|
def state_consume_request_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
|
if isinstance(event, RequestData):
|
||||||
|
self.request_body_buf += event.data
|
||||||
|
elif isinstance(event, RequestEndOfMessage):
|
||||||
|
self.flow.request.data.content = self.request_body_buf
|
||||||
|
self.request_body_buf = b""
|
||||||
|
yield HttpRequestHook(self.flow)
|
||||||
if self.flow.response:
|
if self.flow.response:
|
||||||
# response was set by an inline script.
|
# response was set by an inline script.
|
||||||
# we now need to emulate the responseheaders hook.
|
# we now need to emulate the responseheaders hook.
|
||||||
yield HttpResponseHeadersHook(self.flow)
|
yield HttpResponseHeadersHook(self.flow)
|
||||||
yield from self.handle_response()
|
yield from self.send_response()
|
||||||
else:
|
else:
|
||||||
connection, err = yield GetHttpConnection(
|
ok = yield from self.make_server_connection()
|
||||||
(self.flow.request.host, self.flow.request.port),
|
if not ok:
|
||||||
self.flow.request.scheme == "https"
|
|
||||||
)
|
|
||||||
if err:
|
|
||||||
yield from self.send_error_response(502, err)
|
|
||||||
self.flow.error = flow.Error(err)
|
|
||||||
yield HttpErrorHook(self.flow)
|
|
||||||
return
|
return
|
||||||
else:
|
|
||||||
self.flow.server_conn = connection
|
|
||||||
|
|
||||||
yield SendHttp(RequestHeaders(self.flow.request, self.stream_id), connection)
|
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
|
||||||
|
yield SendHttp(RequestData(self.stream_id, self.flow.request.data.content), self.context.server)
|
||||||
|
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
||||||
|
|
||||||
if self.flow.request.stream:
|
self.client_state = self.state_done
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
yield SendHttp(RequestData(self.flow.request.data.content, self.stream_id), connection)
|
|
||||||
yield SendHttp(RequestEndOfMessage(self.stream_id), connection)
|
|
||||||
self._handle_event = self.read_response_headers
|
|
||||||
|
|
||||||
@expect(ResponseHeaders)
|
@expect(ResponseHeaders)
|
||||||
def read_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
|
def state_wait_for_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
|
||||||
self.flow.response = event.response
|
self.flow.response = event.response
|
||||||
yield HttpResponseHeadersHook(self.flow)
|
yield HttpResponseHeadersHook(self.flow)
|
||||||
if not self.flow.response.stream:
|
if self.flow.response.stream:
|
||||||
self._handle_event = self.read_response_body
|
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||||
|
self.server_state = self.state_stream_response_body
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
self.server_state = self.state_consume_response_body
|
||||||
|
|
||||||
@expect(ResponseData, ResponseEndOfMessage)
|
@expect(ResponseData, ResponseEndOfMessage)
|
||||||
def read_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
def state_stream_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
|
if isinstance(event, ResponseData):
|
||||||
|
if callable(self.flow.response.stream):
|
||||||
|
data = self.flow.response.stream(event.data)
|
||||||
|
else:
|
||||||
|
data = event.data
|
||||||
|
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
|
||||||
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
|
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||||
|
self.server_state = self.state_done
|
||||||
|
|
||||||
|
@expect(ResponseData, ResponseEndOfMessage)
|
||||||
|
def state_consume_response_body(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
if isinstance(event, ResponseData):
|
if isinstance(event, ResponseData):
|
||||||
self.response_body_buf += event.data
|
self.response_body_buf += event.data
|
||||||
elif isinstance(event, ResponseEndOfMessage):
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
self.flow.response.data.content = self.response_body_buf
|
self.flow.response.data.content = self.response_body_buf
|
||||||
self.response_body_buf = b""
|
self.response_body_buf = b""
|
||||||
yield from self.handle_response()
|
yield from self.send_response()
|
||||||
|
|
||||||
def handle_response(self):
|
def send_response(self):
|
||||||
yield HttpResponseHook(self.flow)
|
yield HttpResponseHook(self.flow)
|
||||||
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client)
|
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||||
|
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
|
||||||
|
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||||
|
self.server_state = self.state_done
|
||||||
|
|
||||||
if self.flow.response.stream:
|
def handle_protocol_error(
|
||||||
raise NotImplementedError
|
self,
|
||||||
|
event: typing.Union[RequestProtocolError, ResponseProtocolError]
|
||||||
|
) -> commands.TCommandGenerator:
|
||||||
|
self.flow.error = flow.Error(event.message)
|
||||||
|
yield HttpErrorHook(self.flow)
|
||||||
|
|
||||||
|
if isinstance(event, RequestProtocolError):
|
||||||
|
yield SendHttp(event, self.context.server)
|
||||||
else:
|
else:
|
||||||
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client)
|
yield SendHttp(event, self.context.client)
|
||||||
|
return
|
||||||
|
|
||||||
|
def make_server_connection(self) -> typing.Generator[commands.Command, typing.Any, bool]:
|
||||||
|
connection, err = yield GetHttpConnection(
|
||||||
|
(self.flow.request.host, self.flow.request.port),
|
||||||
|
self.flow.request.scheme == "https"
|
||||||
|
)
|
||||||
|
if err:
|
||||||
|
yield from self.handle_protocol_error(ResponseProtocolError(self.stream_id, err))
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
self.context.server = self.flow.server_conn = connection
|
||||||
|
return True
|
||||||
|
|
||||||
|
def handle_connect(self) -> commands.TCommandGenerator:
|
||||||
|
yield HttpConnectHook(self.flow)
|
||||||
|
|
||||||
|
self.context.server = Server((self.flow.request.host, self.flow.request.port))
|
||||||
|
if self.context.options.connection_strategy == "eager":
|
||||||
|
err = yield commands.OpenConnection(self.context.server)
|
||||||
|
if err:
|
||||||
|
self.flow.response = http.HTTPResponse.make(
|
||||||
|
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.flow.response:
|
||||||
|
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
|
||||||
|
|
||||||
|
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||||
|
|
||||||
|
if 200 <= self.flow.response.status_code < 300:
|
||||||
|
self.child_layer = NextLayer(self.context)
|
||||||
|
yield from self.child_layer.handle_event(events.Start())
|
||||||
|
self._handle_event = self.passthrough
|
||||||
|
else:
|
||||||
|
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
|
||||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||||
|
|
||||||
def send_error_response(self, status_code: int, message: str, headers=None):
|
@expect(RequestData, RequestEndOfMessage, events.Event)
|
||||||
response = http.make_error_response(status_code, message, headers)
|
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
yield SendHttp(ResponseHeaders(response, self.stream_id), self.context.client)
|
# HTTP events -> normal connection events
|
||||||
yield SendHttp(ResponseData(response.data.content, self.stream_id), self.context.client)
|
if isinstance(event, RequestData):
|
||||||
|
event = events.DataReceived(self.context.client, event.data)
|
||||||
|
elif isinstance(event, RequestEndOfMessage):
|
||||||
|
event = events.ConnectionClosed(self.context.client)
|
||||||
|
|
||||||
|
for command in self.child_layer.handle_event(event):
|
||||||
|
# normal connection events -> HTTP events
|
||||||
|
if isinstance(command, commands.SendData) and command.connection == self.context.client:
|
||||||
|
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
|
||||||
|
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
|
||||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||||
|
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
|
||||||
|
yield from self.passthrough(events.OpenConnectionReply(command, None))
|
||||||
|
else:
|
||||||
|
yield command
|
||||||
|
|
||||||
|
@expect()
|
||||||
|
def state_uninitialized(self, _) -> commands.TCommandGenerator:
|
||||||
|
yield from ()
|
||||||
|
|
||||||
|
@expect()
|
||||||
|
def state_done(self, _) -> commands.TCommandGenerator:
|
||||||
|
yield from ()
|
||||||
|
|
||||||
|
def state_errored(self, _) -> commands.TCommandGenerator:
|
||||||
|
# silently consume every event.
|
||||||
|
yield from ()
|
||||||
|
|
||||||
|
|
||||||
class HTTPLayer(Layer):
|
class HTTPLayer(Layer):
|
||||||
@ -304,8 +330,6 @@ class HTTPLayer(Layer):
|
|||||||
self.connections = {
|
self.connections = {
|
||||||
context.client: Http1Server(context.client)
|
context.client: Http1Server(context.client)
|
||||||
}
|
}
|
||||||
if self.context.server.connected:
|
|
||||||
self.make_http_connection(self.context.server)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
|
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
|
||||||
@ -322,8 +346,6 @@ class HTTPLayer(Layer):
|
|||||||
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
|
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
|
||||||
else:
|
else:
|
||||||
yield from self.make_http_connection(event.command.connection)
|
yield from self.make_http_connection(event.command.connection)
|
||||||
elif isinstance(event, EstablishServerTLSReply) and event.command.connection in self.waiting_for_connection:
|
|
||||||
yield from self.make_http_connection(event.command.connection)
|
|
||||||
elif isinstance(event, events.CommandReply):
|
elif isinstance(event, events.CommandReply):
|
||||||
try:
|
try:
|
||||||
stream = self.stream_by_command.pop(event.command)
|
stream = self.stream_by_command.pop(event.command)
|
||||||
|
@ -1,18 +1,16 @@
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mitmproxy.proxy2 import commands, events
|
from mitmproxy.proxy2 import commands, events
|
||||||
|
|
||||||
StreamId = int
|
StreamId = int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class HttpEvent(events.Event):
|
class HttpEvent(events.Event):
|
||||||
stream_id: StreamId
|
|
||||||
|
|
||||||
# we need stream ids on every event to avoid race conditions
|
# we need stream ids on every event to avoid race conditions
|
||||||
|
stream_id: StreamId
|
||||||
def __init__(self, stream_id: StreamId):
|
|
||||||
self.stream_id = stream_id
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
x = self.__dict__.copy()
|
x = self.__dict__.copy()
|
||||||
|
@ -1,38 +1,28 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mitmproxy import http
|
from mitmproxy import http
|
||||||
from ._base import HttpEvent, StreamId
|
from ._base import HttpEvent
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class RequestHeaders(HttpEvent):
|
class RequestHeaders(HttpEvent):
|
||||||
request: http.HTTPRequest
|
request: http.HTTPRequest
|
||||||
|
|
||||||
def __init__(self, request: http.HTTPRequest, stream_id: StreamId):
|
|
||||||
super().__init__(stream_id)
|
|
||||||
self.request = request
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class ResponseHeaders(HttpEvent):
|
class ResponseHeaders(HttpEvent):
|
||||||
response: http.HTTPResponse
|
response: http.HTTPResponse
|
||||||
|
|
||||||
def __init__(self, response: http.HTTPResponse, stream_id: StreamId):
|
|
||||||
super().__init__(stream_id)
|
|
||||||
self.response = response
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class RequestData(HttpEvent):
|
class RequestData(HttpEvent):
|
||||||
data: bytes
|
data: bytes
|
||||||
|
|
||||||
def __init__(self, data: bytes, stream_id: StreamId):
|
|
||||||
super().__init__(stream_id)
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class ResponseData(HttpEvent):
|
class ResponseData(HttpEvent):
|
||||||
data: bytes
|
data: bytes
|
||||||
|
|
||||||
def __init__(self, data: bytes, stream_id: StreamId):
|
|
||||||
super().__init__(stream_id)
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
|
|
||||||
class RequestEndOfMessage(HttpEvent):
|
class RequestEndOfMessage(HttpEvent):
|
||||||
pass
|
pass
|
||||||
@ -42,6 +32,18 @@ class ResponseEndOfMessage(HttpEvent):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RequestProtocolError(HttpEvent):
|
||||||
|
message: str
|
||||||
|
code: int = 400
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResponseProtocolError(HttpEvent):
|
||||||
|
message: str
|
||||||
|
code: int = 502
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"HttpEvent",
|
"HttpEvent",
|
||||||
"RequestHeaders",
|
"RequestHeaders",
|
||||||
@ -50,4 +52,6 @@ __all__ = [
|
|||||||
"ResponseHeaders",
|
"ResponseHeaders",
|
||||||
"ResponseData",
|
"ResponseData",
|
||||||
"ResponseEndOfMessage",
|
"ResponseEndOfMessage",
|
||||||
|
"RequestProtocolError",
|
||||||
|
"ResponseProtocolError",
|
||||||
]
|
]
|
||||||
|
31
mitmproxy/proxy2/layers/http/_hooks.py
Normal file
31
mitmproxy/proxy2/layers/http/_hooks.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from mitmproxy import http
|
||||||
|
from mitmproxy.proxy2 import commands
|
||||||
|
|
||||||
|
|
||||||
|
class HttpRequestHeadersHook(commands.Hook):
|
||||||
|
name = "requestheaders"
|
||||||
|
flow: http.HTTPFlow
|
||||||
|
|
||||||
|
|
||||||
|
class HttpRequestHook(commands.Hook):
|
||||||
|
name = "request"
|
||||||
|
flow: http.HTTPFlow
|
||||||
|
|
||||||
|
|
||||||
|
class HttpResponseHook(commands.Hook):
|
||||||
|
name = "response"
|
||||||
|
flow: http.HTTPFlow
|
||||||
|
|
||||||
|
|
||||||
|
class HttpResponseHeadersHook(commands.Hook):
|
||||||
|
name = "responseheaders"
|
||||||
|
flow: http.HTTPFlow
|
||||||
|
|
||||||
|
|
||||||
|
class HttpConnectHook(commands.Hook):
|
||||||
|
flow: http.HTTPFlow
|
||||||
|
|
||||||
|
|
||||||
|
class HttpErrorHook(commands.Hook):
|
||||||
|
name = "error"
|
||||||
|
flow: http.HTTPFlow
|
@ -12,17 +12,19 @@ from mitmproxy.proxy2 import commands, events
|
|||||||
from mitmproxy.proxy2.context import Client, Connection, Server
|
from mitmproxy.proxy2.context import Client, Connection, Server
|
||||||
from mitmproxy.proxy2.layers.http._base import StreamId
|
from mitmproxy.proxy2.layers.http._base import StreamId
|
||||||
from ._base import HttpConnection
|
from ._base import HttpConnection
|
||||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \
|
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||||
ResponseHeaders
|
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||||
|
|
||||||
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
||||||
|
|
||||||
|
|
||||||
class Http1Connection(HttpConnection):
|
class Http1Connection(HttpConnection):
|
||||||
conn: Connection
|
conn: Connection
|
||||||
stream_id: StreamId = None
|
stream_id: typing.Optional[StreamId] = None
|
||||||
request: http.HTTPRequest
|
request: typing.Optional[http.HTTPRequest] = None
|
||||||
response: http.HTTPResponse
|
response: typing.Optional[http.HTTPResponse] = None
|
||||||
|
request_done: bool = False
|
||||||
|
response_done: bool = False
|
||||||
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
|
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
|
||||||
body_reader: TBodyReader
|
body_reader: TBodyReader
|
||||||
buf: ReceiveBuffer
|
buf: ReceiveBuffer
|
||||||
@ -58,17 +60,22 @@ class Http1Connection(HttpConnection):
|
|||||||
h11_event = self.body_reader.read_eof()
|
h11_event = self.body_reader.read_eof()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
except h11.ProtocolError:
|
except h11.ProtocolError as e:
|
||||||
raise # FIXME
|
yield commands.CloseConnection(self.conn)
|
||||||
|
if is_request:
|
||||||
|
yield RequestProtocolError(self.stream_id, str(e))
|
||||||
|
else:
|
||||||
|
yield ResponseProtocolError(self.stream_id, str(e))
|
||||||
|
return
|
||||||
|
|
||||||
if h11_event is None:
|
if h11_event is None:
|
||||||
return
|
return
|
||||||
elif isinstance(h11_event, h11.Data):
|
elif isinstance(h11_event, h11.Data):
|
||||||
h11_event.data: bytearray # type checking
|
h11_event.data: bytearray # type checking
|
||||||
if is_request:
|
if is_request:
|
||||||
yield RequestData(bytes(h11_event.data), self.stream_id)
|
yield RequestData(self.stream_id, bytes(h11_event.data))
|
||||||
else:
|
else:
|
||||||
yield ResponseData(bytes(h11_event.data), self.stream_id)
|
yield ResponseData(self.stream_id, bytes(h11_event.data))
|
||||||
elif isinstance(h11_event, h11.EndOfMessage):
|
elif isinstance(h11_event, h11.EndOfMessage):
|
||||||
if is_request:
|
if is_request:
|
||||||
yield RequestEndOfMessage(self.stream_id)
|
yield RequestEndOfMessage(self.stream_id)
|
||||||
@ -96,13 +103,15 @@ class Http1Server(Http1Connection):
|
|||||||
|
|
||||||
def __init__(self, conn: Client):
|
def __init__(self, conn: Client):
|
||||||
super().__init__(conn)
|
super().__init__(conn)
|
||||||
|
self.stream_id = 1
|
||||||
self.state = self.read_request_headers
|
self.state = self.read_request_headers
|
||||||
self.stream_id = -1
|
|
||||||
|
|
||||||
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
|
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
|
||||||
|
assert event.stream_id == self.stream_id
|
||||||
if isinstance(event, ResponseHeaders):
|
if isinstance(event, ResponseHeaders):
|
||||||
self.response = event.response
|
self.response = event.response
|
||||||
raw = http1.assemble_response_head(event.response)
|
raw = http1.assemble_response_head(event.response)
|
||||||
|
yield commands.SendData(self.conn, raw)
|
||||||
if self.request.first_line_format == "authority":
|
if self.request.first_line_format == "authority":
|
||||||
assert self.state == self.wait
|
assert self.state == self.wait
|
||||||
self.body_reader = self.make_body_reader(-1)
|
self.body_reader = self.make_body_reader(-1)
|
||||||
@ -113,24 +122,35 @@ class Http1Server(Http1Connection):
|
|||||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||||
else:
|
else:
|
||||||
raw = event.data
|
raw = event.data
|
||||||
|
yield commands.SendData(self.conn, raw)
|
||||||
elif isinstance(event, ResponseEndOfMessage):
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||||
raw = b"0\r\n\r\n"
|
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
return
|
yield from self.mark_done(response=True)
|
||||||
else:
|
elif isinstance(event, ResponseProtocolError):
|
||||||
raw = False
|
if not self.response:
|
||||||
self.request = None
|
resp = http.make_error_response(event.code, event.message)
|
||||||
self.response = None
|
raw = http1.assemble_response(resp)
|
||||||
self.stream_id += 2
|
yield commands.SendData(self.conn, raw)
|
||||||
self.state = self.read_request_headers
|
yield commands.CloseConnection(self.conn)
|
||||||
yield from self.state(events.DataReceived(self.conn, b""))
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{event}")
|
raise NotImplementedError(f"{event}")
|
||||||
|
|
||||||
if raw:
|
def mark_done(self, *, request: bool = False, response: bool = False):
|
||||||
yield commands.SendData(self.conn, raw)
|
if request:
|
||||||
|
self.request_done = True
|
||||||
|
if response:
|
||||||
|
self.response_done = True
|
||||||
|
if self.request_done and self.response_done:
|
||||||
|
self.request_done = self.response_done = False
|
||||||
|
self.request = self.response = None
|
||||||
|
self.stream_id += 2
|
||||||
|
self.state = self.read_request_headers
|
||||||
|
yield from self.state(events.DataReceived(self.conn, b""))
|
||||||
|
elif self.request_done:
|
||||||
|
self.state = self.wait
|
||||||
|
|
||||||
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
@ -138,7 +158,7 @@ class Http1Server(Http1Connection):
|
|||||||
if request_head:
|
if request_head:
|
||||||
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
||||||
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
||||||
yield RequestHeaders(self.request, self.stream_id)
|
yield RequestHeaders(self.stream_id, self.request)
|
||||||
|
|
||||||
if self.request.first_line_format == "authority":
|
if self.request.first_line_format == "authority":
|
||||||
# The previous proxy server implementation tried to read the request body here:
|
# The previous proxy server implementation tried to read the request body here:
|
||||||
@ -151,17 +171,20 @@ class Http1Server(Http1Connection):
|
|||||||
self.body_reader = self.make_body_reader(expected_size)
|
self.body_reader = self.make_body_reader(expected_size)
|
||||||
self.state = self.read_request_body
|
self.state = self.read_request_body
|
||||||
yield from self.state(event)
|
yield from self.state(event)
|
||||||
|
else:
|
||||||
|
pass # FIXME: protect against header size DoS
|
||||||
elif isinstance(event, events.ConnectionClosed):
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
pass # TODO: Better handling, tear everything down.
|
if bytes(self.buf).strip():
|
||||||
|
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
|
||||||
|
yield commands.CloseConnection(self.conn)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
|
|
||||||
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||||
for e in self.read_body(event, True):
|
for e in self.read_body(event, True):
|
||||||
if isinstance(e, RequestEndOfMessage):
|
|
||||||
self.state = self.wait
|
|
||||||
yield from self.state(event)
|
|
||||||
yield e
|
yield e
|
||||||
|
if isinstance(e, RequestEndOfMessage):
|
||||||
|
yield from self.mark_done(request=True)
|
||||||
|
|
||||||
|
|
||||||
class Http1Client(Http1Connection):
|
class Http1Client(Http1Connection):
|
||||||
@ -188,25 +211,40 @@ class Http1Client(Http1Connection):
|
|||||||
|
|
||||||
if isinstance(event, RequestHeaders):
|
if isinstance(event, RequestHeaders):
|
||||||
raw = http1.assemble_request_head(event.request)
|
raw = http1.assemble_request_head(event.request)
|
||||||
|
yield commands.SendData(self.conn, raw)
|
||||||
elif isinstance(event, RequestData):
|
elif isinstance(event, RequestData):
|
||||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||||
else:
|
else:
|
||||||
raw = event.data
|
raw = event.data
|
||||||
|
yield commands.SendData(self.conn, raw)
|
||||||
elif isinstance(event, RequestEndOfMessage):
|
elif isinstance(event, RequestEndOfMessage):
|
||||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||||
raw = b"0\r\n\r\n"
|
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||||
elif http1.expected_http_body_size(self.request) == -1:
|
elif http1.expected_http_body_size(self.request) == -1:
|
||||||
assert not self.send_queue
|
assert not self.send_queue
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
|
yield from self.mark_done(request=True)
|
||||||
|
elif isinstance(event, RequestProtocolError):
|
||||||
|
yield commands.CloseConnection(self.conn)
|
||||||
return
|
return
|
||||||
else:
|
|
||||||
raw = False
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{event}")
|
raise NotImplementedError(f"{event}")
|
||||||
|
|
||||||
if raw:
|
def mark_done(self, *, request: bool = False, response: bool = False):
|
||||||
yield commands.SendData(self.conn, raw)
|
if request:
|
||||||
|
self.request_done = True
|
||||||
|
if response:
|
||||||
|
self.response_done = True
|
||||||
|
if self.request_done and self.response_done:
|
||||||
|
self.request_done = self.response_done = False
|
||||||
|
self.request = self.response = None
|
||||||
|
self.stream_id = None
|
||||||
|
if self.send_queue:
|
||||||
|
send_queue = self.send_queue
|
||||||
|
self.send_queue = []
|
||||||
|
for ev in send_queue:
|
||||||
|
yield from self.send(ev)
|
||||||
|
|
||||||
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
|
||||||
assert isinstance(event, events.ConnectionEvent)
|
assert isinstance(event, events.ConnectionEvent)
|
||||||
@ -216,18 +254,27 @@ class Http1Client(Http1Connection):
|
|||||||
if response_head:
|
if response_head:
|
||||||
response_head = [bytes(x) for x in response_head]
|
response_head = [bytes(x) for x in response_head]
|
||||||
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
|
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
|
||||||
yield ResponseHeaders(self.response, self.stream_id)
|
yield ResponseHeaders(self.stream_id, self.response)
|
||||||
|
|
||||||
expected_size = http1.expected_http_body_size(self.request, self.response)
|
expected_size = http1.expected_http_body_size(self.request, self.response)
|
||||||
self.body_reader = self.make_body_reader(expected_size)
|
self.body_reader = self.make_body_reader(expected_size)
|
||||||
|
|
||||||
self.state = self.read_response_body
|
self.state = self.read_response_body
|
||||||
yield from self.state(event)
|
yield from self.state(event)
|
||||||
|
else:
|
||||||
|
pass # FIXME: protect against header size DoS
|
||||||
elif isinstance(event, events.ConnectionClosed):
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
if self.stream_id:
|
if self.stream_id:
|
||||||
raise NotImplementedError(f"{event}")
|
if self.buf:
|
||||||
|
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}")
|
||||||
|
else:
|
||||||
|
# The server has closed the connection to prevent us from continuing.
|
||||||
|
# We need to signal that to the stream.
|
||||||
|
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
||||||
|
yield ResponseProtocolError(self.stream_id, "server closed connection")
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
yield commands.CloseConnection(self.conn)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
|
|
||||||
@ -237,14 +284,7 @@ class Http1Client(Http1Connection):
|
|||||||
yield e
|
yield e
|
||||||
if isinstance(e, ResponseEndOfMessage):
|
if isinstance(e, ResponseEndOfMessage):
|
||||||
self.state = self.read_response_headers
|
self.state = self.read_response_headers
|
||||||
self.stream_id = None
|
yield from self.mark_done(response=True)
|
||||||
self.request = None
|
|
||||||
self.response = None
|
|
||||||
if self.send_queue:
|
|
||||||
send_queue = self.send_queue
|
|
||||||
self.send_queue = []
|
|
||||||
for ev in send_queue:
|
|
||||||
yield from self.send(ev)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mitmproxy.http import HTTPResponse, HTTPFlow
|
from mitmproxy.http import HTTPFlow, HTTPResponse
|
||||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
from mitmproxy.proxy2 import layer
|
from mitmproxy.proxy2 import layer
|
||||||
from mitmproxy.proxy2.commands import OpenConnection, SendData
|
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||||
from mitmproxy.proxy2.layers import http, tls
|
from mitmproxy.proxy2.layers import http, tls
|
||||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
|
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
|
||||||
@ -168,6 +168,22 @@ def test_http_reply_from_proxy(tctx):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_until_eof(tctx):
|
||||||
|
"""Test scenario where the server response body is terminated by EOF."""
|
||||||
|
server = Placeholder()
|
||||||
|
assert (
|
||||||
|
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
<< OpenConnection(server)
|
||||||
|
>> reply(None)
|
||||||
|
<< SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
>> DataReceived(server, b"HTTP/1.1 200 OK\r\n\r\nfoo")
|
||||||
|
>> ConnectionClosed(server)
|
||||||
|
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\n\r\nfoo")
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_while_intercept(tctx):
|
def test_disconnect_while_intercept(tctx):
|
||||||
"""Test a server disconnect while a request is intercepted."""
|
"""Test a server disconnect while a request is intercepted."""
|
||||||
tctx.options.connection_strategy = "eager"
|
tctx.options.connection_strategy = "eager"
|
||||||
@ -198,3 +214,130 @@ def test_disconnect_while_intercept(tctx):
|
|||||||
)
|
)
|
||||||
assert server1() != server2()
|
assert server1() != server2()
|
||||||
assert flow().server_conn == server2()
|
assert flow().server_conn == server2()
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_streaming(tctx):
|
||||||
|
"""Test HTTP response streaming"""
|
||||||
|
server = Placeholder()
|
||||||
|
flow = Placeholder()
|
||||||
|
|
||||||
|
def enable_streaming(flow: HTTPFlow):
|
||||||
|
flow.response.stream = lambda x: x.upper()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
<< OpenConnection(server)
|
||||||
|
>> reply(None)
|
||||||
|
<< SendData(server, b"GET /largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nabc")
|
||||||
|
<< http.HttpResponseHeadersHook(flow)
|
||||||
|
>> reply(side_effect=enable_streaming)
|
||||||
|
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nABC")
|
||||||
|
>> DataReceived(server, b"def")
|
||||||
|
<< SendData(tctx.client, b"DEF")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("response", ["normal response", "early response", "early close", "early kill"])
|
||||||
|
def test_request_streaming(tctx, response):
|
||||||
|
"""
|
||||||
|
Test HTTP request streaming
|
||||||
|
|
||||||
|
This is a bit more contrived as we may receive server data while we are still sending the request.
|
||||||
|
"""
|
||||||
|
server = Placeholder()
|
||||||
|
flow = Placeholder()
|
||||||
|
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
|
||||||
|
def enable_streaming(flow: HTTPFlow):
|
||||||
|
flow.request.stream = lambda x: x.upper()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> DataReceived(tctx.client, b"POST http://example.com/ HTTP/1.1\r\n"
|
||||||
|
b"Host: example.com\r\n"
|
||||||
|
b"Content-Length: 6\r\n\r\n"
|
||||||
|
b"abc")
|
||||||
|
<< http.HttpRequestHeadersHook(flow)
|
||||||
|
>> reply(side_effect=enable_streaming)
|
||||||
|
<< OpenConnection(server)
|
||||||
|
>> reply(None)
|
||||||
|
<< SendData(server, b"POST / HTTP/1.1\r\n"
|
||||||
|
b"Host: example.com\r\n"
|
||||||
|
b"Content-Length: 6\r\n\r\n"
|
||||||
|
b"ABC")
|
||||||
|
)
|
||||||
|
if response == "normal response":
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> DataReceived(tctx.client, b"def")
|
||||||
|
<< SendData(server, b"DEF")
|
||||||
|
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||||
|
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||||
|
)
|
||||||
|
elif response == "early response":
|
||||||
|
# We may receive a response before we have finished sending our request.
|
||||||
|
# We continue sending unless the server closes the connection.
|
||||||
|
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> DataReceived(server, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||||
|
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||||
|
>> DataReceived(tctx.client, b"def")
|
||||||
|
<< SendData(server, b"DEF")
|
||||||
|
)
|
||||||
|
elif response == "early close":
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> DataReceived(server, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||||
|
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
|
||||||
|
>> ConnectionClosed(server)
|
||||||
|
<< CloseConnection(server)
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
)
|
||||||
|
elif response == "early kill":
|
||||||
|
err = Placeholder()
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> ConnectionClosed(server)
|
||||||
|
<< CloseConnection(server)
|
||||||
|
<< SendData(tctx.client, err)
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
)
|
||||||
|
assert b"502 Bad Gateway" in err()
|
||||||
|
else: # pragma: no cover
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("data", [
|
||||||
|
None,
|
||||||
|
b"I don't speak HTTP.",
|
||||||
|
b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nweee"
|
||||||
|
])
|
||||||
|
def test_server_aborts(tctx, data):
|
||||||
|
"""Test the scenario where the server doesn't serve a response"""
|
||||||
|
server = Placeholder()
|
||||||
|
flow = Placeholder()
|
||||||
|
err = Placeholder()
|
||||||
|
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
<< OpenConnection(server)
|
||||||
|
>> reply(None)
|
||||||
|
<< SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
)
|
||||||
|
if data:
|
||||||
|
playbook >> DataReceived(server, data)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> ConnectionClosed(server)
|
||||||
|
<< CloseConnection(server)
|
||||||
|
<< http.HttpErrorHook(flow)
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.client, err)
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
)
|
||||||
|
assert flow().error
|
||||||
|
assert b"502 Bad Gateway" in err()
|
||||||
|
Loading…
Reference in New Issue
Block a user