[sans-io] implement http streaming, refine error handling

This commit is contained in:
Maximilian Hils 2019-11-25 04:13:21 +01:00
parent 5671012163
commit 0740c673bd
8 changed files with 460 additions and 237 deletions

View File

@ -112,6 +112,9 @@ class Hook(Command):
all_hooks: typing.Dict[str, typing.Type[Hook]] = {}
# TODO: Move descriptions from addons/events.py into hooks and have hook documentation generated from all_hooks.
class GetSocket(ConnectionCommand):
"""
Get the underlying socket.

View File

@ -5,6 +5,7 @@ The counterpart to events are commands.
"""
import socket
import typing
from dataclasses import dataclass
from mitmproxy.proxy2 import commands
from mitmproxy.proxy2.context import Connection
@ -27,15 +28,13 @@ class Start(Event):
pass
@dataclass
class ConnectionEvent(Event):
"""
All events involving connection IO.
"""
connection: Connection
def __init__(self, connection: Connection):
self.connection = connection
class ConnectionClosed(ConnectionEvent):
"""
@ -44,20 +43,19 @@ class ConnectionClosed(ConnectionEvent):
pass
@dataclass
class DataReceived(ConnectionEvent):
"""
Remote has sent some data.
"""
def __init__(self, connection: Connection, data: bytes) -> None:
super().__init__(connection)
self.data = data
data: bytes
def __repr__(self):
target = type(self.connection).__name__.lower()
return f"DataReceived({target}, {self.data})"
@dataclass
class CommandReply(Event):
"""
Emitted when a command has been finished, e.g.
@ -66,10 +64,6 @@ class CommandReply(Event):
command: commands.Command
reply: typing.Any
def __init__(self, command: commands.Command, reply: typing.Any):
self.command = command
self.reply = reply
def __new__(cls, *args, **kwargs):
if cls is CommandReply:
raise TypeError("CommandReply may not be instantiated directly.")
@ -88,35 +82,23 @@ class CommandReply(Event):
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandReply]] = {}
@dataclass
class OpenConnectionReply(CommandReply):
command: commands.OpenConnection
reply: typing.Optional[str]
def __init__(
self,
command: commands.OpenConnection,
err: typing.Optional[str]
):
super().__init__(command, err)
"""error message"""
@dataclass
class HookReply(CommandReply):
command: commands.Hook
def __init__(self, command: commands.Hook):
super().__init__(command, None)
reply: None = None
def __repr__(self):
return f"HookReply({repr(self.command)[5:-1]})"
@dataclass
class GetSocketReply(CommandReply):
command: commands.GetSocket
reply: socket.socket
def __init__(
self,
command: commands.GetSocket,
socket: socket.socket
):
super().__init__(command, socket)

View File

@ -10,8 +10,10 @@ from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSRe
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
from ._base import HttpConnection, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \
ResponseHeaders
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
HttpResponseHook
from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client
@ -42,8 +44,8 @@ class GetHttpConnection(HttpCommand):
class GetHttpConnectionReply(events.CommandReply):
command: GetHttpConnection
reply: typing.Optional[str]
"""error message"""
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
"""connection object, error message"""
class SendHttp(HttpCommand):
@ -58,35 +60,6 @@ class SendHttp(HttpCommand):
return f"Send({self.event})"
class HttpRequestHeadersHook(commands.Hook):
name = "requestheaders"
flow: http.HTTPFlow
class HttpRequestHook(commands.Hook):
name = "request"
flow: http.HTTPFlow
class HttpResponseHook(commands.Hook):
name = "response"
flow: http.HTTPFlow
class HttpResponseHeadersHook(commands.Hook):
name = "responseheaders"
flow: http.HTTPFlow
class HttpConnectHook(commands.Hook):
flow: http.HTTPFlow
class HttpErrorHook(commands.Hook):
name = "error"
flow: http.HTTPFlow
class HttpStream(Layer):
request_body_buf: bytes
response_body_buf: bytes
@ -103,15 +76,22 @@ class HttpStream(Layer):
super().__init__(context)
self.request_body_buf = b""
self.response_body_buf = b""
self._handle_event = self.start
self.client_state = self.state_uninitialized
self.server_state = self.state_uninitialized
@expect(events.Start)
def start(self, event: events.Event) -> commands.TCommandGenerator:
self._handle_event = self.read_request_headers
yield from ()
@expect(events.Start, HttpEvent)
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.Start):
self.client_state = self.state_wait_for_request_headers
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
yield from self.handle_protocol_error(event)
elif isinstance(event, (RequestHeaders, RequestData, RequestEndOfMessage)):
yield from self.client_state(event)
else:
yield from self.server_state(event)
@expect(RequestHeaders)
def read_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
def state_wait_for_request_headers(self, event: RequestHeaders) -> commands.TCommandGenerator:
self.stream_id = event.stream_id
self.flow = http.HTTPFlow(
self.context.client,
@ -122,72 +102,12 @@ class HttpStream(Layer):
if self.flow.request.first_line_format == "authority":
yield from self.handle_connect()
return
else:
yield HttpRequestHeadersHook(self.flow)
if self.flow.request.headers.get("expect", "").lower() == "100-continue":
raise NotImplementedError("expect nothing")
# self.send_response(http.expect_continue_response)
# request.headers.pop("expect")
if self.flow.request.stream:
raise NotImplementedError # FIXME
else:
self._handle_event = self.read_request_body
@expect(RequestData, RequestEndOfMessage)
def read_request_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, RequestData):
self.request_body_buf += event.data
elif isinstance(event, RequestEndOfMessage):
self.flow.request.data.content = self.request_body_buf
self.request_body_buf = b""
yield from self.handle_request()
def handle_connect(self) -> commands.TCommandGenerator:
yield HttpConnectHook(self.flow)
self.context.server = Server((self.flow.request.host, self.flow.request.port))
if self.context.options.connection_strategy == "eager":
err = yield commands.OpenConnection(self.context.server)
if err:
self.flow.response = http.HTTPResponse.make(
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
)
if not self.flow.response:
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client)
if 200 <= self.flow.response.status_code < 300:
self.child_layer = NextLayer(self.context)
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
else:
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
@expect(RequestData, RequestEndOfMessage, events.Event)
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
# HTTP events -> normal connection events
if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data)
elif isinstance(event, RequestEndOfMessage):
event = events.ConnectionClosed(self.context.client)
for command in self.child_layer.handle_event(event):
# normal connection events -> HTTP events
if isinstance(command, commands.SendData) and command.connection == self.context.client:
yield SendHttp(ResponseData(command.data, self.stream_id), self.context.client)
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
yield from self.passthrough(events.OpenConnectionReply(command, None))
else:
yield command
def handle_request(self) -> commands.TCommandGenerator:
# set first line format to relative in regular mode,
# see https://github.com/mitmproxy/mitmproxy/issues/1759
if self.mode is HTTPMode.regular and self.flow.request.first_line_format == "absolute":
@ -210,68 +130,174 @@ class HttpStream(Layer):
self.flow.request.port = self.context.server.address[1]
self.flow.request.host_header = host_header # set again as .host overwrites this.
self.flow.request.scheme = "https" if self.context.server.tls else "http"
yield HttpRequestHook(self.flow)
yield HttpRequestHeadersHook(self.flow)
if self.flow.request.stream:
if self.flow.response:
raise NotImplementedError("Can't set a response and enable streaming at the same time.")
ok = yield from self.make_server_connection()
if not ok:
return
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
self.client_state = self.state_stream_request_body
else:
self.client_state = self.state_consume_request_body
self.server_state = self.state_wait_for_response_headers
@expect(RequestData, RequestEndOfMessage)
def state_stream_request_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, RequestData):
if callable(self.flow.request.stream):
data = self.flow.request.stream(event.data)
else:
data = event.data
yield SendHttp(RequestData(self.stream_id, data), self.context.server)
elif isinstance(event, RequestEndOfMessage):
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
self.client_state = self.state_done
@expect(RequestData, RequestEndOfMessage)
def state_consume_request_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, RequestData):
self.request_body_buf += event.data
elif isinstance(event, RequestEndOfMessage):
self.flow.request.data.content = self.request_body_buf
self.request_body_buf = b""
yield HttpRequestHook(self.flow)
if self.flow.response:
# response was set by an inline script.
# we now need to emulate the responseheaders hook.
yield HttpResponseHeadersHook(self.flow)
yield from self.handle_response()
yield from self.send_response()
else:
connection, err = yield GetHttpConnection(
(self.flow.request.host, self.flow.request.port),
self.flow.request.scheme == "https"
)
if err:
yield from self.send_error_response(502, err)
self.flow.error = flow.Error(err)
yield HttpErrorHook(self.flow)
ok = yield from self.make_server_connection()
if not ok:
return
else:
self.flow.server_conn = connection
yield SendHttp(RequestHeaders(self.flow.request, self.stream_id), connection)
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request), self.context.server)
yield SendHttp(RequestData(self.stream_id, self.flow.request.data.content), self.context.server)
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
if self.flow.request.stream:
raise NotImplementedError
else:
yield SendHttp(RequestData(self.flow.request.data.content, self.stream_id), connection)
yield SendHttp(RequestEndOfMessage(self.stream_id), connection)
self._handle_event = self.read_response_headers
self.client_state = self.state_done
@expect(ResponseHeaders)
def read_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
def state_wait_for_response_headers(self, event: ResponseHeaders) -> commands.TCommandGenerator:
self.flow.response = event.response
yield HttpResponseHeadersHook(self.flow)
if not self.flow.response.stream:
self._handle_event = self.read_response_body
if self.flow.response.stream:
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
self.server_state = self.state_stream_response_body
else:
raise NotImplementedError
self.server_state = self.state_consume_response_body
@expect(ResponseData, ResponseEndOfMessage)
def read_response_body(self, event: events.Event) -> commands.TCommandGenerator:
def state_stream_response_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, ResponseData):
if callable(self.flow.response.stream):
data = self.flow.response.stream(event.data)
else:
data = event.data
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
elif isinstance(event, ResponseEndOfMessage):
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.server_state = self.state_done
@expect(ResponseData, ResponseEndOfMessage)
def state_consume_response_body(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, ResponseData):
self.response_body_buf += event.data
elif isinstance(event, ResponseEndOfMessage):
self.flow.response.data.content = self.response_body_buf
self.response_body_buf = b""
yield from self.handle_response()
yield from self.send_response()
def handle_response(self):
def send_response(self):
yield HttpResponseHook(self.flow)
yield SendHttp(ResponseHeaders(self.flow.response, self.stream_id), self.context.client)
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.server_state = self.state_done
if self.flow.response.stream:
raise NotImplementedError
def handle_protocol_error(
self,
event: typing.Union[RequestProtocolError, ResponseProtocolError]
) -> commands.TCommandGenerator:
self.flow.error = flow.Error(event.message)
yield HttpErrorHook(self.flow)
if isinstance(event, RequestProtocolError):
yield SendHttp(event, self.context.server)
else:
yield SendHttp(ResponseData(self.flow.response.data.content, self.stream_id), self.context.client)
yield SendHttp(event, self.context.client)
return
def make_server_connection(self) -> typing.Generator[commands.Command, typing.Any, bool]:
connection, err = yield GetHttpConnection(
(self.flow.request.host, self.flow.request.port),
self.flow.request.scheme == "https"
)
if err:
yield from self.handle_protocol_error(ResponseProtocolError(self.stream_id, err))
return False
else:
self.context.server = self.flow.server_conn = connection
return True
def handle_connect(self) -> commands.TCommandGenerator:
yield HttpConnectHook(self.flow)
self.context.server = Server((self.flow.request.host, self.flow.request.port))
if self.context.options.connection_strategy == "eager":
err = yield commands.OpenConnection(self.context.server)
if err:
self.flow.response = http.HTTPResponse.make(
502, f"Cannot connect to {human.format_address(self.context.server.address)}: {err}"
)
if not self.flow.response:
self.flow.response = http.make_connect_response(self.flow.request.data.http_version)
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
if 200 <= self.flow.response.status_code < 300:
self.child_layer = NextLayer(self.context)
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
else:
yield SendHttp(ResponseData(self.stream_id, self.flow.response.data.content), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
def send_error_response(self, status_code: int, message: str, headers=None):
response = http.make_error_response(status_code, message, headers)
yield SendHttp(ResponseHeaders(response, self.stream_id), self.context.client)
yield SendHttp(ResponseData(response.data.content, self.stream_id), self.context.client)
@expect(RequestData, RequestEndOfMessage, events.Event)
def passthrough(self, event: events.Event) -> commands.TCommandGenerator:
# HTTP events -> normal connection events
if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data)
elif isinstance(event, RequestEndOfMessage):
event = events.ConnectionClosed(self.context.client)
for command in self.child_layer.handle_event(event):
# normal connection events -> HTTP events
if isinstance(command, commands.SendData) and command.connection == self.context.client:
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
elif isinstance(command, commands.OpenConnection) and command.connection == self.context.server:
yield from self.passthrough(events.OpenConnectionReply(command, None))
else:
yield command
@expect()
def state_uninitialized(self, _) -> commands.TCommandGenerator:
yield from ()
@expect()
def state_done(self, _) -> commands.TCommandGenerator:
yield from ()
def state_errored(self, _) -> commands.TCommandGenerator:
# silently consume every event.
yield from ()
class HTTPLayer(Layer):
@ -304,8 +330,6 @@ class HTTPLayer(Layer):
self.connections = {
context.client: Http1Server(context.client)
}
if self.context.server.connected:
self.make_http_connection(self.context.server)
def __repr__(self):
return f"HTTPLayer(conns: {len(self.connections)}, events: {[type(e).__name__ for e in self.event_queue]})"
@ -322,8 +346,6 @@ class HTTPLayer(Layer):
self.event_to_child(stream, GetHttpConnectionReply(cmd, (None, event.reply)))
else:
yield from self.make_http_connection(event.command.connection)
elif isinstance(event, EstablishServerTLSReply) and event.command.connection in self.waiting_for_connection:
yield from self.make_http_connection(event.command.connection)
elif isinstance(event, events.CommandReply):
try:
stream = self.stream_by_command.pop(event.command)

View File

@ -1,18 +1,16 @@
import abc
import typing
from dataclasses import dataclass
from mitmproxy.proxy2 import commands, events
StreamId = int
@dataclass
class HttpEvent(events.Event):
stream_id: StreamId
# we need stream ids on every event to avoid race conditions
def __init__(self, stream_id: StreamId):
self.stream_id = stream_id
stream_id: StreamId
def __repr__(self) -> str:
x = self.__dict__.copy()

View File

@ -1,38 +1,28 @@
from dataclasses import dataclass
from mitmproxy import http
from ._base import HttpEvent, StreamId
from ._base import HttpEvent
@dataclass
class RequestHeaders(HttpEvent):
request: http.HTTPRequest
def __init__(self, request: http.HTTPRequest, stream_id: StreamId):
super().__init__(stream_id)
self.request = request
@dataclass
class ResponseHeaders(HttpEvent):
response: http.HTTPResponse
def __init__(self, response: http.HTTPResponse, stream_id: StreamId):
super().__init__(stream_id)
self.response = response
@dataclass
class RequestData(HttpEvent):
data: bytes
def __init__(self, data: bytes, stream_id: StreamId):
super().__init__(stream_id)
self.data = data
@dataclass
class ResponseData(HttpEvent):
data: bytes
def __init__(self, data: bytes, stream_id: StreamId):
super().__init__(stream_id)
self.data = data
class RequestEndOfMessage(HttpEvent):
pass
@ -42,6 +32,18 @@ class ResponseEndOfMessage(HttpEvent):
pass
@dataclass
class RequestProtocolError(HttpEvent):
message: str
code: int = 400
@dataclass
class ResponseProtocolError(HttpEvent):
message: str
code: int = 502
__all__ = [
"HttpEvent",
"RequestHeaders",
@ -50,4 +52,6 @@ __all__ = [
"ResponseHeaders",
"ResponseData",
"ResponseEndOfMessage",
"RequestProtocolError",
"ResponseProtocolError",
]

View File

@ -0,0 +1,31 @@
from mitmproxy import http
from mitmproxy.proxy2 import commands
class HttpRequestHeadersHook(commands.Hook):
name = "requestheaders"
flow: http.HTTPFlow
class HttpRequestHook(commands.Hook):
name = "request"
flow: http.HTTPFlow
class HttpResponseHook(commands.Hook):
name = "response"
flow: http.HTTPFlow
class HttpResponseHeadersHook(commands.Hook):
name = "responseheaders"
flow: http.HTTPFlow
class HttpConnectHook(commands.Hook):
flow: http.HTTPFlow
class HttpErrorHook(commands.Hook):
name = "error"
flow: http.HTTPFlow

View File

@ -12,17 +12,19 @@ from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.context import Client, Connection, Server
from mitmproxy.proxy2.layers.http._base import StreamId
from ._base import HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, \
ResponseHeaders
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
class Http1Connection(HttpConnection):
conn: Connection
stream_id: StreamId = None
request: http.HTTPRequest
response: http.HTTPResponse
stream_id: typing.Optional[StreamId] = None
request: typing.Optional[http.HTTPRequest] = None
response: typing.Optional[http.HTTPResponse] = None
request_done: bool = False
response_done: bool = False
state: typing.Callable[[events.Event], typing.Iterator[HttpEvent]]
body_reader: TBodyReader
buf: ReceiveBuffer
@ -58,17 +60,22 @@ class Http1Connection(HttpConnection):
h11_event = self.body_reader.read_eof()
else:
raise ValueError(f"Unexpected event: {event}")
except h11.ProtocolError:
raise # FIXME
except h11.ProtocolError as e:
yield commands.CloseConnection(self.conn)
if is_request:
yield RequestProtocolError(self.stream_id, str(e))
else:
yield ResponseProtocolError(self.stream_id, str(e))
return
if h11_event is None:
return
elif isinstance(h11_event, h11.Data):
h11_event.data: bytearray # type checking
if is_request:
yield RequestData(bytes(h11_event.data), self.stream_id)
yield RequestData(self.stream_id, bytes(h11_event.data))
else:
yield ResponseData(bytes(h11_event.data), self.stream_id)
yield ResponseData(self.stream_id, bytes(h11_event.data))
elif isinstance(h11_event, h11.EndOfMessage):
if is_request:
yield RequestEndOfMessage(self.stream_id)
@ -96,13 +103,15 @@ class Http1Server(Http1Connection):
def __init__(self, conn: Client):
super().__init__(conn)
self.stream_id = 1
self.state = self.read_request_headers
self.stream_id = -1
def send(self, event: HttpEvent) -> commands.TCommandGenerator:
assert event.stream_id == self.stream_id
if isinstance(event, ResponseHeaders):
self.response = event.response
raw = http1.assemble_response_head(event.response)
yield commands.SendData(self.conn, raw)
if self.request.first_line_format == "authority":
assert self.state == self.wait
self.body_reader = self.make_body_reader(-1)
@ -113,24 +122,35 @@ class Http1Server(Http1Connection):
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
else:
raw = event.data
yield commands.SendData(self.conn, raw)
elif isinstance(event, ResponseEndOfMessage):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
raw = b"0\r\n\r\n"
yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn)
return
else:
raw = False
self.request = None
self.response = None
self.stream_id += 2
self.state = self.read_request_headers
yield from self.state(events.DataReceived(self.conn, b""))
yield from self.mark_done(response=True)
elif isinstance(event, ResponseProtocolError):
if not self.response:
resp = http.make_error_response(event.code, event.message)
raw = http1.assemble_response(resp)
yield commands.SendData(self.conn, raw)
yield commands.CloseConnection(self.conn)
else:
raise NotImplementedError(f"{event}")
if raw:
yield commands.SendData(self.conn, raw)
def mark_done(self, *, request: bool = False, response: bool = False):
if request:
self.request_done = True
if response:
self.response_done = True
if self.request_done and self.response_done:
self.request_done = self.response_done = False
self.request = self.response = None
self.stream_id += 2
self.state = self.read_request_headers
yield from self.state(events.DataReceived(self.conn, b""))
elif self.request_done:
self.state = self.wait
def read_request_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
if isinstance(event, events.DataReceived):
@ -138,7 +158,7 @@ class Http1Server(Http1Connection):
if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
yield RequestHeaders(self.request, self.stream_id)
yield RequestHeaders(self.stream_id, self.request)
if self.request.first_line_format == "authority":
# The previous proxy server implementation tried to read the request body here:
@ -151,17 +171,20 @@ class Http1Server(Http1Connection):
self.body_reader = self.make_body_reader(expected_size)
self.state = self.read_request_body
yield from self.state(event)
else:
pass # FIXME: protect against header size DoS
elif isinstance(event, events.ConnectionClosed):
pass # TODO: Better handling, tear everything down.
if bytes(self.buf).strip():
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
yield commands.CloseConnection(self.conn)
else:
raise ValueError(f"Unexpected event: {event}")
def read_request_body(self, event: events.Event) -> typing.Iterator[HttpEvent]:
for e in self.read_body(event, True):
if isinstance(e, RequestEndOfMessage):
self.state = self.wait
yield from self.state(event)
yield e
if isinstance(e, RequestEndOfMessage):
yield from self.mark_done(request=True)
class Http1Client(Http1Connection):
@ -188,25 +211,40 @@ class Http1Client(Http1Connection):
if isinstance(event, RequestHeaders):
raw = http1.assemble_request_head(event.request)
yield commands.SendData(self.conn, raw)
elif isinstance(event, RequestData):
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
else:
raw = event.data
yield commands.SendData(self.conn, raw)
elif isinstance(event, RequestEndOfMessage):
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
raw = b"0\r\n\r\n"
yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request) == -1:
assert not self.send_queue
yield commands.CloseConnection(self.conn)
yield from self.mark_done(request=True)
elif isinstance(event, RequestProtocolError):
yield commands.CloseConnection(self.conn)
return
else:
raw = False
else:
raise NotImplementedError(f"{event}")
if raw:
yield commands.SendData(self.conn, raw)
def mark_done(self, *, request: bool = False, response: bool = False):
if request:
self.request_done = True
if response:
self.response_done = True
if self.request_done and self.response_done:
self.request_done = self.response_done = False
self.request = self.response = None
self.stream_id = None
if self.send_queue:
send_queue = self.send_queue
self.send_queue = []
for ev in send_queue:
yield from self.send(ev)
def read_response_headers(self, event: events.Event) -> typing.Iterator[HttpEvent]:
assert isinstance(event, events.ConnectionEvent)
@ -216,18 +254,27 @@ class Http1Client(Http1Connection):
if response_head:
response_head = [bytes(x) for x in response_head]
self.response = http.HTTPResponse.wrap(http1_sansio.read_response_head(response_head))
yield ResponseHeaders(self.response, self.stream_id)
yield ResponseHeaders(self.stream_id, self.response)
expected_size = http1.expected_http_body_size(self.request, self.response)
self.body_reader = self.make_body_reader(expected_size)
self.state = self.read_response_body
yield from self.state(event)
else:
pass # FIXME: protect against header size DoS
elif isinstance(event, events.ConnectionClosed):
if self.stream_id:
raise NotImplementedError(f"{event}")
if self.buf:
yield ResponseProtocolError(self.stream_id, f"unexpected server response: {bytes(self.buf)}")
else:
# The server has closed the connection to prevent us from continuing.
# We need to signal that to the stream.
# https://tools.ietf.org/html/rfc7231#section-6.5.11
yield ResponseProtocolError(self.stream_id, "server closed connection")
else:
return
yield commands.CloseConnection(self.conn)
else:
raise ValueError(f"Unexpected event: {event}")
@ -237,14 +284,7 @@ class Http1Client(Http1Connection):
yield e
if isinstance(e, ResponseEndOfMessage):
self.state = self.read_response_headers
self.stream_id = None
self.request = None
self.response = None
if self.send_queue:
send_queue = self.send_queue
self.send_queue = []
for ev in send_queue:
yield from self.send(ev)
yield from self.mark_done(response=True)
__all__ = [

View File

@ -1,9 +1,9 @@
import pytest
from mitmproxy.http import HTTPResponse, HTTPFlow
from mitmproxy.http import HTTPFlow, HTTPResponse
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import layer
from mitmproxy.proxy2.commands import OpenConnection, SendData
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import http, tls
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
@ -168,6 +168,22 @@ def test_http_reply_from_proxy(tctx):
)
def test_response_until_eof(tctx):
"""Test scenario where the server response body is terminated by EOF."""
server = Placeholder()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
>> DataReceived(server, b"HTTP/1.1 200 OK\r\n\r\nfoo")
>> ConnectionClosed(server)
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\n\r\nfoo")
<< CloseConnection(tctx.client)
)
def test_disconnect_while_intercept(tctx):
"""Test a server disconnect while a request is intercepted."""
tctx.options.connection_strategy = "eager"
@ -198,3 +214,130 @@ def test_disconnect_while_intercept(tctx):
)
assert server1() != server2()
assert flow().server_conn == server2()
def test_response_streaming(tctx):
"""Test HTTP response streaming"""
server = Placeholder()
flow = Placeholder()
def enable_streaming(flow: HTTPFlow):
flow.response.stream = lambda x: x.upper()
assert (
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"GET /largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nabc")
<< http.HttpResponseHeadersHook(flow)
>> reply(side_effect=enable_streaming)
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nABC")
>> DataReceived(server, b"def")
<< SendData(tctx.client, b"DEF")
)
@pytest.mark.parametrize("response", ["normal response", "early response", "early close", "early kill"])
def test_request_streaming(tctx, response):
"""
Test HTTP request streaming
This is a bit more contrived as we may receive server data while we are still sending the request.
"""
server = Placeholder()
flow = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
def enable_streaming(flow: HTTPFlow):
flow.request.stream = lambda x: x.upper()
assert (
playbook
>> DataReceived(tctx.client, b"POST http://example.com/ HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 6\r\n\r\n"
b"abc")
<< http.HttpRequestHeadersHook(flow)
>> reply(side_effect=enable_streaming)
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 6\r\n\r\n"
b"ABC")
)
if response == "normal response":
assert (
playbook
>> DataReceived(tctx.client, b"def")
<< SendData(server, b"DEF")
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
)
elif response == "early response":
# We may receive a response before we have finished sending our request.
# We continue sending unless the server closes the connection.
# https://tools.ietf.org/html/rfc7231#section-6.5.11
assert (
playbook
>> DataReceived(server, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
>> DataReceived(tctx.client, b"def")
<< SendData(server, b"DEF")
)
elif response == "early close":
assert (
playbook
>> DataReceived(server, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
>> ConnectionClosed(server)
<< CloseConnection(server)
<< CloseConnection(tctx.client)
)
elif response == "early kill":
err = Placeholder()
assert (
playbook
>> ConnectionClosed(server)
<< CloseConnection(server)
<< SendData(tctx.client, err)
<< CloseConnection(tctx.client)
)
assert b"502 Bad Gateway" in err()
else: # pragma: no cover
assert False
@pytest.mark.parametrize("data", [
None,
b"I don't speak HTTP.",
b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nweee"
])
def test_server_aborts(tctx, data):
"""Test the scenario where the server doesn't serve a response"""
server = Placeholder()
flow = Placeholder()
err = Placeholder()
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
assert (
playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
)
if data:
playbook >> DataReceived(server, data)
assert (
playbook
>> ConnectionClosed(server)
<< CloseConnection(server)
<< http.HttpErrorHook(flow)
>> reply()
<< SendData(tctx.client, err)
<< CloseConnection(tctx.client)
)
assert flow().error
assert b"502 Bad Gateway" in err()