[sans-io] unify HTTP/1 client/server implementations

This commit is contained in:
Maximilian Hils 2020-12-07 22:50:06 +01:00
parent 7afa290eff
commit ab733b73ca
2 changed files with 149 additions and 148 deletions

View File

@ -4,6 +4,7 @@ from typing import Iterable, List, Optional, Tuple
from mitmproxy.net import check
from mitmproxy.net.http import headers, request, response, url
from mitmproxy.net.http.http1 import read
def _parse_authority_form(hostport: bytes) -> Tuple[bytes, int]:
@ -165,3 +166,16 @@ def read_response_head(lines: List[bytes]) -> response.Response:
timestamp_start=time.time(),
timestamp_end=None,
)
def expected_http_body_size(
request: request.Request,
response: Optional[response.Response] = None,
expect_continue_as_0: bool = True,
):
"""
Like the non-sans-io version, but also treating CONNECT as content-length: 0
"""
if request.data.method.upper() == b"CONNECT":
return 0
return read.expected_http_body_size(request, response, expect_continue_as_0)

View File

@ -1,5 +1,5 @@
import abc
import typing
from typing import Union, Optional, Callable
import h11
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
@ -7,6 +7,7 @@ from h11._receivebuffer import ReceiveBuffer
from mitmproxy import exceptions, http
from mitmproxy.net.http import http1, status_codes
from mitmproxy.net import http as net_http
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2.context import Connection, ConnectionState, Context
@ -17,44 +18,51 @@ from ._base import HttpConnection
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
TBodyReader = Union[ChunkedReader, Http10Reader, ContentLengthReader]
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
stream_id: typing.Optional[StreamId] = None
request: typing.Optional[http.HTTPRequest] = None
response: typing.Optional[http.HTTPResponse] = None
stream_id: Optional[StreamId] = None
request: Optional[http.HTTPRequest] = None
response: Optional[http.HTTPResponse] = None
request_done: bool = False
response_done: bool = False
state: typing.Callable[[events.Event], layer.CommandGenerator[None]]
state: Callable[[events.ConnectionEvent], layer.CommandGenerator[None]]
body_reader: TBodyReader
buf: ReceiveBuffer
ReceiveProtocolError: Callable[[int, str], Union[RequestProtocolError, ResponseProtocolError]]
ReceiveData: Callable[[int, bytes], Union[RequestData, ResponseData]]
ReceiveEndOfMessage: Callable[[int], Union[RequestEndOfMessage, ResponseEndOfMessage]]
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
self.buf = ReceiveBuffer()
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, HttpEvent):
yield from self.send(event)
else:
if isinstance(event, events.DataReceived):
self.buf += event.data
yield from self.state(event)
@abc.abstractmethod
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
yield from () # pragma: no cover
def make_body_reader(self, expected_size: typing.Optional[int]) -> TBodyReader:
if expected_size is None:
return ChunkedReader()
elif expected_size == -1:
return Http10Reader()
else:
return ContentLengthReader(expected_size)
@abc.abstractmethod
def read_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
yield from () # pragma: no cover
def read_body(self, event: events.Event, is_request: bool) -> layer.CommandGenerator[None]:
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, HttpEvent):
yield from self.send(event)
else:
if isinstance(event, events.DataReceived) and self.state != self.passthrough:
self.buf += event.data
yield from self.state(event)
@expect(events.Start)
def start(self, _) -> layer.CommandGenerator[None]:
self.state = self.read_headers
yield from ()
state = start
def read_body(self, event: events.Event) -> layer.CommandGenerator[None]:
while True:
try:
if isinstance(event, events.DataReceived):
@ -65,10 +73,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
raise AssertionError(f"Unexpected event: {event}")
except h11.ProtocolError as e:
yield commands.CloseConnection(self.conn)
if is_request:
yield ReceiveHttp(RequestProtocolError(self.stream_id, f"HTTP/1 protocol error: {e}"))
else:
yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"HTTP/1 protocol error: {e}"))
yield ReceiveHttp(self.ReceiveProtocolError(self.stream_id, f"HTTP/1 protocol error: {e}"))
return
if h11_event is None:
@ -76,17 +81,17 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
elif isinstance(h11_event, h11.Data):
data: bytes = bytes(h11_event.data)
if data:
if is_request:
yield ReceiveHttp(RequestData(self.stream_id, data))
else:
yield ReceiveHttp(ResponseData(self.stream_id, data))
yield ReceiveHttp(self.ReceiveData(self.stream_id, data))
elif isinstance(h11_event, h11.EndOfMessage):
if h11_event.headers:
raise NotImplementedError(f"HTTP trailers are not implemented yet.")
if is_request:
yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
else:
yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
if self.request.data.method.upper() != b"CONNECT":
yield ReceiveHttp(self.ReceiveEndOfMessage(self.stream_id))
is_request = isinstance(self, Http1Server)
yield from self.mark_done(
request=is_request,
response=not is_request
)
return
def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
@ -103,17 +108,72 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
if event.connection.state is not ConnectionState.CLOSED:
yield commands.CloseConnection(event.connection)
else: # pragma: no cover
yield from ()
raise AssertionError(f"Unexpected event: {event}")
def done(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
yield from () # pragma: no cover
def make_pipe(self) -> layer.CommandGenerator[None]:
self.state = self.passthrough
if self.buf:
already_received = self.buf.maybe_extract_at_most(len(self.buf))
yield from self.state(events.DataReceived(self.conn, already_received))
self.buf.compress()
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
yield ReceiveHttp(self.ReceiveData(self.stream_id, event.data))
elif isinstance(event, events.ConnectionClosed):
if isinstance(self, Http1Server):
yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
else:
yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
if request:
self.request_done = True
if response:
self.response_done = True
if self.request_done and self.response_done:
if should_make_pipe(self.request, self.response):
yield from self.make_pipe()
return
connection_done = (
http1_sansio.expected_http_body_size(self.request, self.response) == -1 or
http1.connection_close(self.request.http_version, self.request.headers) or
http1.connection_close(self.response.http_version, self.response.headers) or
(
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
# This simplifies our connection management quite a bit as we can rely on
# the proxyserver's max-connection-per-server throttling.
self.request.is_http2 and isinstance(self, Http1Client)
)
)
if connection_done:
yield commands.CloseConnection(self.conn)
self.state = self.done
return
self.request_done = self.response_done = False
self.request = self.response = None
if isinstance(self, Http1Server):
self.stream_id += 2
else:
self.stream_id = None
self.state = self.read_headers
if self.buf:
yield from self.state(events.DataReceived(self.conn, b""))
class Http1Server(Http1Connection):
"""A simple HTTP/1 server with no pipelining support."""
ReceiveProtocolError = RequestProtocolError
ReceiveData = RequestData
ReceiveEndOfMessage = RequestEndOfMessage
def __init__(self, context: Context):
super().__init__(context, context.client)
self.stream_id = 1
self.state = self.start
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
assert event.stream_id == self.stream_id
@ -131,11 +191,6 @@ class Http1Server(Http1Connection):
raw = http1.assemble_response_head(response)
yield commands.SendData(self.conn, raw)
if self.request.first_line_format == "authority":
assert self.state == self.wait
self.body_reader = self.make_body_reader(-1)
self.state = self.read_request_body
yield from self.state(events.DataReceived(self.conn, b""))
elif isinstance(event, ResponseData):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
@ -146,8 +201,7 @@ class Http1Server(Http1Connection):
elif isinstance(event, ResponseEndOfMessage):
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
if self.request.first_line_format != "authority":
yield from self.mark_done(response=True)
yield from self.mark_done(response=True)
elif isinstance(event, ResponseProtocolError):
if not self.response:
resp = http.make_error_response(event.code, event.message)
@ -157,59 +211,23 @@ class Http1Server(Http1Connection):
else:
raise AssertionError(f"Unexpected event: {event}")
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
if request:
self.request_done = True
if response:
self.response_done = True
if self.request_done and self.response_done:
connection_done = (
http1.expected_http_body_size(self.request, self.response) == -1 or
http1.connection_close(self.request.http_version, self.request.headers) or
http1.connection_close(self.response.http_version, self.response.headers)
)
if connection_done:
yield commands.CloseConnection(self.conn)
self.state = self.wait
return
self.request_done = self.response_done = False
self.request = self.response = None
self.stream_id += 2
self.state = self.read_request_headers
yield from self.state(events.DataReceived(self.conn, b""))
elif self.request_done:
self.state = self.wait
@expect(events.Start)
def start(self, event: events.Start) -> layer.CommandGenerator[None]:
self.state = self.read_request_headers
yield from ()
def read_request_headers(self, event: events.Event) -> layer.CommandGenerator[None]:
def read_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
request_head = self.buf.maybe_extract_lines()
if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
try:
self.request = http1_sansio.read_request_head(request_head)
expected_body_size = http1.expected_http_body_size(self.request, expect_continue_as_0=False)
expected_body_size = http1_sansio.expected_http_body_size(self.request, expect_continue_as_0=False)
except (ValueError, exceptions.HttpSyntaxException) as e:
yield commands.Log(f"{human.format_address(self.conn.peername)}: {e}")
yield commands.CloseConnection(self.conn)
self.state = self.wait
self.state = self.done
return
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request, expected_body_size == 0))
if self.request.first_line_format == "authority":
# The previous proxy server implementation tried to read the request body here:
# https://github.com/mitmproxy/mitmproxy/blob/45e3ae0f9cb50b0edbf4180fd969ea99d40bdf7b/mitmproxy/proxy/protocol/http.py#L251-L255
# We don't do this to be compliant with the h2 spec:
# https://http2.github.io/http2-spec/#CONNECT
self.state = self.wait
else:
self.body_reader = self.make_body_reader(expected_body_size)
self.state = self.read_request_body
yield from self.state(event)
self.body_reader = make_body_reader(expected_body_size)
self.state = self.read_body
yield from self.state(event)
else:
pass # FIXME: protect against header size DoS
elif isinstance(event, events.ConnectionClosed):
@ -220,33 +238,28 @@ class Http1Server(Http1Connection):
else:
raise AssertionError(f"Unexpected event: {event}")
def read_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
for e in self.read_body(event, True):
yield e
if isinstance(e, ReceiveHttp) and isinstance(e.event, RequestEndOfMessage):
yield from self.mark_done(request=True)
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
yield from super().mark_done(request=request, response=response)
if self.request_done and not self.response_done:
self.state = self.wait
class Http1Client(Http1Connection):
send_queue: typing.List[HttpEvent]
"""A queue of send events for flows other than the one that is currently being transmitted."""
"""A simple HTTP/1 client with no pipelining support."""
ReceiveProtocolError = ResponseProtocolError
ReceiveData = ResponseData
ReceiveEndOfMessage = ResponseEndOfMessage
def __init__(self, context: Context):
super().__init__(context, context.server)
self.state = self.start
self.send_queue = []
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
if not self.stream_id:
assert isinstance(event, RequestHeaders)
self.stream_id = event.stream_id
self.request = event.request
if self.stream_id != event.stream_id:
# Assuming an h2 server, we may have multiple Streams that try to send requests
# over a single h1 connection. To keep things relatively simple, we don't do any HTTP/1 pipelining
# but keep a queue of still-to-send requests.
self.send_queue.append(event)
return
assert self.stream_id == event.stream_id
if isinstance(event, RequestHeaders):
request = event.request
@ -269,8 +282,7 @@ class Http1Client(Http1Connection):
elif isinstance(event, RequestEndOfMessage):
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request, self.response) == -1:
assert not self.send_queue
elif http1_sansio.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn, half_close=True)
yield from self.mark_done(request=True)
elif isinstance(event, RequestProtocolError):
@ -279,42 +291,7 @@ class Http1Client(Http1Connection):
else:
raise AssertionError(f"Unexpected event: {event}")
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
if request:
self.request_done = True
if response:
self.response_done = True
if self.request_done and self.response_done:
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
# This simplifies our connection management quite a bit as we can rely on
# the proxyserver's max-connection-per-server throttling.
connection_done = (
http1.expected_http_body_size(self.request, self.response) == -1 or
http1.connection_close(self.request.http_version, self.request.headers) or
http1.connection_close(self.response.http_version, self.response.headers) or
self.request.is_http2
)
if connection_done:
assert not self.send_queue
yield commands.CloseConnection(self.conn)
self.state = self.wait
return
self.request_done = self.response_done = False
self.request = self.response = None
self.stream_id = None
if self.send_queue:
send_queue = self.send_queue
self.send_queue = []
for ev in send_queue:
yield from self.send(ev)
@expect(events.Start)
def start(self, event: events.Start) -> layer.CommandGenerator[None]:
self.state = self.read_response_headers
yield from ()
@expect(events.ConnectionEvent)
def read_response_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
def read_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
if not self.request:
# we just received some data for an unknown request.
@ -327,15 +304,15 @@ class Http1Client(Http1Connection):
response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays
try:
self.response = http1_sansio.read_response_head(response_head)
expected_size = http1.expected_http_body_size(self.request, self.response)
expected_size = http1_sansio.expected_http_body_size(self.request, self.response)
except (ValueError, exceptions.HttpSyntaxException) as e:
yield commands.CloseConnection(self.conn)
yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"Cannot parse HTTP response: {e}"))
return
yield ReceiveHttp(ResponseHeaders(self.stream_id, self.response, expected_size == 0))
self.body_reader = self.make_body_reader(expected_size)
self.body_reader = make_body_reader(expected_size)
self.state = self.read_response_body
self.state = self.read_body
yield from self.state(event)
else:
pass # FIXME: protect against header size DoS
@ -355,13 +332,23 @@ class Http1Client(Http1Connection):
else:
raise AssertionError(f"Unexpected event: {event}")
@expect(events.ConnectionEvent)
def read_response_body(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
for e in self.read_body(event, False):
yield e
if isinstance(e, ReceiveHttp) and isinstance(e.event, ResponseEndOfMessage):
self.state = self.read_response_headers
yield from self.mark_done(response=True)
def should_make_pipe(request: net_http.Request, response: net_http.Response) -> bool:
if response.status_code == 101:
return True
elif response.status_code == 200 and request.method.upper() == "CONNECT":
return True
else:
return False
def make_body_reader(expected_size: Optional[int]) -> TBodyReader:
if expected_size is None:
return ChunkedReader()
elif expected_size == -1:
return Http10Reader()
else:
return ContentLengthReader(expected_size)
__all__ = [