mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] unify HTTP/1 client/server implementations
This commit is contained in:
parent
7afa290eff
commit
ab733b73ca
@ -4,6 +4,7 @@ from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
from mitmproxy.net import check
|
||||
from mitmproxy.net.http import headers, request, response, url
|
||||
from mitmproxy.net.http.http1 import read
|
||||
|
||||
|
||||
def _parse_authority_form(hostport: bytes) -> Tuple[bytes, int]:
|
||||
@ -165,3 +166,16 @@ def read_response_head(lines: List[bytes]) -> response.Response:
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=None,
|
||||
)
|
||||
|
||||
|
||||
def expected_http_body_size(
    request: request.Request,
    response: Optional[response.Response] = None,
    expect_continue_as_0: bool = True,
):
    """
    Sans-io wrapper around read.expected_http_body_size.

    CONNECT requests (method compared case-insensitively) are always reported
    with a body size of 0; every other request is passed through unchanged to
    the non-sans-io implementation.
    """
    is_connect = request.data.method.upper() == b"CONNECT"
    if not is_connect:
        return read.expected_http_body_size(request, response, expect_continue_as_0)
    return 0
|
||||
|
@ -1,5 +1,5 @@
|
||||
import abc
|
||||
import typing
|
||||
from typing import Union, Optional, Callable
|
||||
|
||||
import h11
|
||||
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
|
||||
@ -7,6 +7,7 @@ from h11._receivebuffer import ReceiveBuffer
|
||||
|
||||
from mitmproxy import exceptions, http
|
||||
from mitmproxy.net.http import http1, status_codes
|
||||
from mitmproxy.net import http as net_http
|
||||
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
|
||||
from mitmproxy.proxy2 import commands, events, layer
|
||||
from mitmproxy.proxy2.context import Connection, ConnectionState, Context
|
||||
@ -17,44 +18,51 @@ from ._base import HttpConnection
|
||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||
|
||||
TBodyReader = typing.Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
||||
TBodyReader = Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
||||
|
||||
|
||||
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
stream_id: typing.Optional[StreamId] = None
|
||||
request: typing.Optional[http.HTTPRequest] = None
|
||||
response: typing.Optional[http.HTTPResponse] = None
|
||||
stream_id: Optional[StreamId] = None
|
||||
request: Optional[http.HTTPRequest] = None
|
||||
response: Optional[http.HTTPResponse] = None
|
||||
request_done: bool = False
|
||||
response_done: bool = False
|
||||
state: typing.Callable[[events.Event], layer.CommandGenerator[None]]
|
||||
state: Callable[[events.ConnectionEvent], layer.CommandGenerator[None]]
|
||||
body_reader: TBodyReader
|
||||
buf: ReceiveBuffer
|
||||
|
||||
ReceiveProtocolError: Callable[[int, str], Union[RequestProtocolError, ResponseProtocolError]]
|
||||
ReceiveData: Callable[[int, bytes], Union[RequestData, ResponseData]]
|
||||
ReceiveEndOfMessage: Callable[[int], Union[RequestEndOfMessage, ResponseEndOfMessage]]
|
||||
|
||||
def __init__(self, context: Context, conn: Connection):
|
||||
super().__init__(context, conn)
|
||||
self.buf = ReceiveBuffer()
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, HttpEvent):
|
||||
yield from self.send(event)
|
||||
else:
|
||||
if isinstance(event, events.DataReceived):
|
||||
self.buf += event.data
|
||||
yield from self.state(event)
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
yield from () # pragma: no cover
|
||||
|
||||
def make_body_reader(self, expected_size: typing.Optional[int]) -> TBodyReader:
|
||||
if expected_size is None:
|
||||
return ChunkedReader()
|
||||
elif expected_size == -1:
|
||||
return Http10Reader()
|
||||
else:
|
||||
return ContentLengthReader(expected_size)
|
||||
@abc.abstractmethod
|
||||
def read_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
yield from () # pragma: no cover
|
||||
|
||||
def read_body(self, event: events.Event, is_request: bool) -> layer.CommandGenerator[None]:
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
    """
    Dispatch an incoming event.

    HttpEvents are handed to send(). Every other event is given to the
    current state handler; received bytes are appended to the receive
    buffer first, unless the connection is in the passthrough state.
    """
    if isinstance(event, HttpEvent):
        yield from self.send(event)
        return

    if isinstance(event, events.DataReceived) and self.state != self.passthrough:
        # passthrough() forwards event.data directly, so nothing is buffered there.
        self.buf += event.data
    yield from self.state(event)
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
self.state = self.read_headers
|
||||
yield from ()
|
||||
|
||||
state = start
|
||||
|
||||
def read_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
while True:
|
||||
try:
|
||||
if isinstance(event, events.DataReceived):
|
||||
@ -65,10 +73,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
except h11.ProtocolError as e:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
if is_request:
|
||||
yield ReceiveHttp(RequestProtocolError(self.stream_id, f"HTTP/1 protocol error: {e}"))
|
||||
else:
|
||||
yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"HTTP/1 protocol error: {e}"))
|
||||
yield ReceiveHttp(self.ReceiveProtocolError(self.stream_id, f"HTTP/1 protocol error: {e}"))
|
||||
return
|
||||
|
||||
if h11_event is None:
|
||||
@ -76,17 +81,17 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
elif isinstance(h11_event, h11.Data):
|
||||
data: bytes = bytes(h11_event.data)
|
||||
if data:
|
||||
if is_request:
|
||||
yield ReceiveHttp(RequestData(self.stream_id, data))
|
||||
else:
|
||||
yield ReceiveHttp(ResponseData(self.stream_id, data))
|
||||
yield ReceiveHttp(self.ReceiveData(self.stream_id, data))
|
||||
elif isinstance(h11_event, h11.EndOfMessage):
|
||||
if h11_event.headers:
|
||||
raise NotImplementedError(f"HTTP trailers are not implemented yet.")
|
||||
if is_request:
|
||||
yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
|
||||
else:
|
||||
yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
|
||||
if self.request.data.method.upper() != b"CONNECT":
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(self.stream_id))
|
||||
is_request = isinstance(self, Http1Server)
|
||||
yield from self.mark_done(
|
||||
request=is_request,
|
||||
response=not is_request
|
||||
)
|
||||
return
|
||||
|
||||
def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
@ -103,17 +108,72 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
if event.connection.state is not ConnectionState.CLOSED:
|
||||
yield commands.CloseConnection(event.connection)
|
||||
else: # pragma: no cover
|
||||
yield from ()
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def done(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
    """
    State handler that ignores the event and emits no commands.

    mark_done() switches to this state after it has closed the connection.
    """
    return  # pragma: no cover
    yield  # pragma: no cover  (unreachable; keeps this function a generator)
|
||||
|
||||
def make_pipe(self) -> layer.CommandGenerator[None]:
    """
    Switch the connection into passthrough mode.

    Any bytes still sitting in the receive buffer are drained and replayed
    through the passthrough handler as a single DataReceived event.
    """
    self.state = self.passthrough
    if self.buf:
        leftover = self.buf.maybe_extract_at_most(len(self.buf))
        replay = events.DataReceived(self.conn, leftover)
        yield from self.state(replay)
    # NOTE(review): assumed to run unconditionally (after the if-block) — confirm against upstream.
    self.buf.compress()
|
||||
|
||||
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
    """
    State handler used once the connection has been turned into a raw pipe.

    Received bytes are forwarded unparsed as body data for the current
    stream; a closed connection is reported as the end of the message.
    Other events are ignored.
    """
    if isinstance(event, events.DataReceived):
        yield ReceiveHttp(self.ReceiveData(self.stream_id, event.data))
    elif isinstance(event, events.ConnectionClosed):
        # Use the per-side event constructor (RequestEndOfMessage for the server
        # side, ResponseEndOfMessage for the client side) instead of checking the
        # concrete subclass here, consistent with ReceiveData above.
        yield ReceiveHttp(self.ReceiveEndOfMessage(self.stream_id))
|
||||
|
||||
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
    """
    Record that the request and/or the response has been fully processed.

    Nothing further happens until both sides are done. Then, in order:

    - if should_make_pipe() says so, the connection becomes a raw pipe;
    - if the connection cannot be reused, it is closed and the state
      becomes `done`;
    - otherwise all per-flow state is reset and the connection goes back to
      reading headers, immediately re-processing any bytes already buffered.
    """
    if request:
        self.request_done = True
    if response:
        self.response_done = True

    if not (self.request_done and self.response_done):
        return

    if should_make_pipe(self.request, self.response):
        yield from self.make_pipe()
        return

    must_close = (
        http1_sansio.expected_http_body_size(self.request, self.response) == -1 or
        http1.connection_close(self.request.http_version, self.request.headers) or
        http1.connection_close(self.response.http_version, self.response.headers) or
        (
            # If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
            # This simplifies our connection management quite a bit as we can rely on
            # the proxyserver's max-connection-per-server throttling.
            self.request.is_http2 and isinstance(self, Http1Client)
        )
    )
    if must_close:
        yield commands.CloseConnection(self.conn)
        self.state = self.done
        return

    # Connection is reusable: forget the finished flow.
    self.request_done = False
    self.response_done = False
    self.request = None
    self.response = None
    if isinstance(self, Http1Server):
        self.stream_id += 2
    else:
        self.stream_id = None
    self.state = self.read_headers
    if self.buf:
        # Data for the next message is already buffered; kick the header
        # reader with an empty DataReceived so it gets parsed right away.
        yield from self.state(events.DataReceived(self.conn, b""))
|
||||
|
||||
|
||||
class Http1Server(Http1Connection):
|
||||
"""A simple HTTP/1 server with no pipelining support."""
|
||||
|
||||
ReceiveProtocolError = RequestProtocolError
|
||||
ReceiveData = RequestData
|
||||
ReceiveEndOfMessage = RequestEndOfMessage
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.client)
|
||||
self.stream_id = 1
|
||||
self.state = self.start
|
||||
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
assert event.stream_id == self.stream_id
|
||||
@ -131,11 +191,6 @@ class Http1Server(Http1Connection):
|
||||
|
||||
raw = http1.assemble_response_head(response)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
if self.request.first_line_format == "authority":
|
||||
assert self.state == self.wait
|
||||
self.body_reader = self.make_body_reader(-1)
|
||||
self.state = self.read_request_body
|
||||
yield from self.state(events.DataReceived(self.conn, b""))
|
||||
elif isinstance(event, ResponseData):
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||
@ -146,8 +201,7 @@ class Http1Server(Http1Connection):
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
if self.request.first_line_format != "authority":
|
||||
yield from self.mark_done(response=True)
|
||||
yield from self.mark_done(response=True)
|
||||
elif isinstance(event, ResponseProtocolError):
|
||||
if not self.response:
|
||||
resp = http.make_error_response(event.code, event.message)
|
||||
@ -157,59 +211,23 @@ class Http1Server(Http1Connection):
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
|
||||
if request:
|
||||
self.request_done = True
|
||||
if response:
|
||||
self.response_done = True
|
||||
if self.request_done and self.response_done:
|
||||
connection_done = (
|
||||
http1.expected_http_body_size(self.request, self.response) == -1 or
|
||||
http1.connection_close(self.request.http_version, self.request.headers) or
|
||||
http1.connection_close(self.response.http_version, self.response.headers)
|
||||
)
|
||||
if connection_done:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
self.state = self.wait
|
||||
return
|
||||
self.request_done = self.response_done = False
|
||||
self.request = self.response = None
|
||||
self.stream_id += 2
|
||||
self.state = self.read_request_headers
|
||||
yield from self.state(events.DataReceived(self.conn, b""))
|
||||
elif self.request_done:
|
||||
self.state = self.wait
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, event: events.Start) -> layer.CommandGenerator[None]:
|
||||
self.state = self.read_request_headers
|
||||
yield from ()
|
||||
|
||||
def read_request_headers(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
def read_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
request_head = self.buf.maybe_extract_lines()
|
||||
if request_head:
|
||||
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
||||
try:
|
||||
self.request = http1_sansio.read_request_head(request_head)
|
||||
expected_body_size = http1.expected_http_body_size(self.request, expect_continue_as_0=False)
|
||||
expected_body_size = http1_sansio.expected_http_body_size(self.request, expect_continue_as_0=False)
|
||||
except (ValueError, exceptions.HttpSyntaxException) as e:
|
||||
yield commands.Log(f"{human.format_address(self.conn.peername)}: {e}")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
self.state = self.wait
|
||||
self.state = self.done
|
||||
return
|
||||
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request, expected_body_size == 0))
|
||||
|
||||
if self.request.first_line_format == "authority":
|
||||
# The previous proxy server implementation tried to read the request body here:
|
||||
# https://github.com/mitmproxy/mitmproxy/blob/45e3ae0f9cb50b0edbf4180fd969ea99d40bdf7b/mitmproxy/proxy/protocol/http.py#L251-L255
|
||||
# We don't do this to be compliant with the h2 spec:
|
||||
# https://http2.github.io/http2-spec/#CONNECT
|
||||
self.state = self.wait
|
||||
else:
|
||||
self.body_reader = self.make_body_reader(expected_body_size)
|
||||
self.state = self.read_request_body
|
||||
yield from self.state(event)
|
||||
self.body_reader = make_body_reader(expected_body_size)
|
||||
self.state = self.read_body
|
||||
yield from self.state(event)
|
||||
else:
|
||||
pass # FIXME: protect against header size DoS
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
@ -220,33 +238,28 @@ class Http1Server(Http1Connection):
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def read_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
for e in self.read_body(event, True):
|
||||
yield e
|
||||
if isinstance(e, ReceiveHttp) and isinstance(e.event, RequestEndOfMessage):
|
||||
yield from self.mark_done(request=True)
|
||||
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
    """
    Run the shared completion bookkeeping, then switch to the `wait` state
    if the request is complete while the response is still outstanding.
    """
    yield from super().mark_done(request=request, response=response)
    response_pending = self.request_done and not self.response_done
    if response_pending:
        self.state = self.wait
|
||||
|
||||
|
||||
class Http1Client(Http1Connection):
|
||||
send_queue: typing.List[HttpEvent]
|
||||
"""A queue of send events for flows other than the one that is currently being transmitted."""
|
||||
"""A simple HTTP/1 client with no pipelining support."""
|
||||
|
||||
ReceiveProtocolError = ResponseProtocolError
|
||||
ReceiveData = ResponseData
|
||||
ReceiveEndOfMessage = ResponseEndOfMessage
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.server)
|
||||
self.state = self.start
|
||||
self.send_queue = []
|
||||
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
if not self.stream_id:
|
||||
assert isinstance(event, RequestHeaders)
|
||||
self.stream_id = event.stream_id
|
||||
self.request = event.request
|
||||
if self.stream_id != event.stream_id:
|
||||
# Assuming an h2 server, we may have multiple Streams that try to send requests
|
||||
# over a single h1 connection. To keep things relatively simple, we don't do any HTTP/1 pipelining
|
||||
# but keep a queue of still-to-send requests.
|
||||
self.send_queue.append(event)
|
||||
return
|
||||
assert self.stream_id == event.stream_id
|
||||
|
||||
if isinstance(event, RequestHeaders):
|
||||
request = event.request
|
||||
@ -269,8 +282,7 @@ class Http1Client(Http1Connection):
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||
assert not self.send_queue
|
||||
elif http1_sansio.expected_http_body_size(self.request, self.response) == -1:
|
||||
yield commands.CloseConnection(self.conn, half_close=True)
|
||||
yield from self.mark_done(request=True)
|
||||
elif isinstance(event, RequestProtocolError):
|
||||
@ -279,42 +291,7 @@ class Http1Client(Http1Connection):
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def mark_done(self, *, request: bool = False, response: bool = False) -> layer.CommandGenerator[None]:
|
||||
if request:
|
||||
self.request_done = True
|
||||
if response:
|
||||
self.response_done = True
|
||||
if self.request_done and self.response_done:
|
||||
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
|
||||
# This simplifies our connection management quite a bit as we can rely on
|
||||
# the proxyserver's max-connection-per-server throttling.
|
||||
connection_done = (
|
||||
http1.expected_http_body_size(self.request, self.response) == -1 or
|
||||
http1.connection_close(self.request.http_version, self.request.headers) or
|
||||
http1.connection_close(self.response.http_version, self.response.headers) or
|
||||
self.request.is_http2
|
||||
)
|
||||
if connection_done:
|
||||
assert not self.send_queue
|
||||
yield commands.CloseConnection(self.conn)
|
||||
self.state = self.wait
|
||||
return
|
||||
self.request_done = self.response_done = False
|
||||
self.request = self.response = None
|
||||
self.stream_id = None
|
||||
if self.send_queue:
|
||||
send_queue = self.send_queue
|
||||
self.send_queue = []
|
||||
for ev in send_queue:
|
||||
yield from self.send(ev)
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, event: events.Start) -> layer.CommandGenerator[None]:
|
||||
self.state = self.read_response_headers
|
||||
yield from ()
|
||||
|
||||
@expect(events.ConnectionEvent)
|
||||
def read_response_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
def read_headers(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
if not self.request:
|
||||
# we just received some data for an unknown request.
|
||||
@ -327,15 +304,15 @@ class Http1Client(Http1Connection):
|
||||
response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays
|
||||
try:
|
||||
self.response = http1_sansio.read_response_head(response_head)
|
||||
expected_size = http1.expected_http_body_size(self.request, self.response)
|
||||
expected_size = http1_sansio.expected_http_body_size(self.request, self.response)
|
||||
except (ValueError, exceptions.HttpSyntaxException) as e:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"Cannot parse HTTP response: {e}"))
|
||||
return
|
||||
yield ReceiveHttp(ResponseHeaders(self.stream_id, self.response, expected_size == 0))
|
||||
self.body_reader = self.make_body_reader(expected_size)
|
||||
self.body_reader = make_body_reader(expected_size)
|
||||
|
||||
self.state = self.read_response_body
|
||||
self.state = self.read_body
|
||||
yield from self.state(event)
|
||||
else:
|
||||
pass # FIXME: protect against header size DoS
|
||||
@ -355,13 +332,23 @@ class Http1Client(Http1Connection):
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
@expect(events.ConnectionEvent)
|
||||
def read_response_body(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
for e in self.read_body(event, False):
|
||||
yield e
|
||||
if isinstance(e, ReceiveHttp) and isinstance(e.event, ResponseEndOfMessage):
|
||||
self.state = self.read_response_headers
|
||||
yield from self.mark_done(response=True)
|
||||
|
||||
def should_make_pipe(request: net_http.Request, response: net_http.Response) -> bool:
    """
    Decide whether the connection should become a raw pipe after this exchange.

    True for any response with status 101, and for a status 200 response to
    a CONNECT request (method compared case-insensitively); False otherwise.
    """
    status = response.status_code
    return status == 101 or (status == 200 and request.method.upper() == "CONNECT")
|
||||
|
||||
|
||||
def make_body_reader(expected_size: Optional[int]) -> TBodyReader:
    """
    Pick the h11 body reader matching an expected body size.

    None selects a ChunkedReader, -1 selects an Http10Reader, and any other
    value is passed to ContentLengthReader as the number of bytes to read.
    """
    if expected_size is None:
        return ChunkedReader()
    if expected_size == -1:
        return Http10Reader()
    return ContentLengthReader(expected_size)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
Loading…
Reference in New Issue
Block a user