mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] handle 101 switching protocols
This commit is contained in:
parent
396673b2b1
commit
a4a0428bc6
@ -2,10 +2,12 @@ from . import modes
|
||||
from .http import HttpLayer
|
||||
from .tcp import TCPLayer
|
||||
from .tls import ClientTLSLayer, ServerTLSLayer
|
||||
from .websocket import WebsocketLayer
|
||||
|
||||
__all__ = [
|
||||
"modes",
|
||||
"HttpLayer",
|
||||
"TCPLayer",
|
||||
"ClientTLSLayer", "ServerTLSLayer",
|
||||
"WebsocketLayer",
|
||||
]
|
||||
|
@ -9,7 +9,7 @@ from mitmproxy.net.http import url
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events, layer, tunnel
|
||||
from mitmproxy.proxy2.context import Connection, ConnectionState, Context, Server
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.proxy2.layers import tls, websocket, tcp
|
||||
from mitmproxy.proxy2.layers.http import _upstream_proxy
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
@ -226,12 +226,14 @@ class HttpStream(layer.Layer):
|
||||
self.flow.request.timestamp_end = time.time()
|
||||
self.flow.request.data.content = self.request_body_buf
|
||||
self.request_body_buf = b""
|
||||
self.client_state = self.state_done
|
||||
yield HttpRequestHook(self.flow)
|
||||
if (yield from self.check_killed(True)):
|
||||
return
|
||||
elif self.flow.response:
|
||||
# response was set by an inline script.
|
||||
# we now need to emulate the responseheaders hook.
|
||||
self.flow.response.timestamp_start = time.time()
|
||||
yield HttpResponseHeadersHook(self.flow)
|
||||
if (yield from self.check_killed(True)):
|
||||
return
|
||||
@ -247,8 +249,6 @@ class HttpStream(layer.Layer):
|
||||
yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server)
|
||||
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
||||
|
||||
self.client_state = self.state_done
|
||||
|
||||
@expect(ResponseHeaders)
|
||||
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
|
||||
self.flow.response = event.response
|
||||
@ -270,38 +270,51 @@ class HttpStream(layer.Layer):
|
||||
data = event.data
|
||||
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
self.flow.response.timestamp_end = time.time()
|
||||
yield HttpResponseHook(self.flow)
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
self.server_state = self.state_done
|
||||
yield from self.send_response(already_streamed=True)
|
||||
|
||||
@expect(ResponseData, ResponseEndOfMessage)
|
||||
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, ResponseData):
|
||||
self.response_body_buf += event.data
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
self.flow.response.timestamp_end = time.time()
|
||||
self.flow.response.data.content = self.response_body_buf
|
||||
self.response_body_buf = b""
|
||||
yield from self.send_response()
|
||||
self.server_state = self.state_done
|
||||
|
||||
def send_response(self):
|
||||
def send_response(self, already_streamed: bool = False):
|
||||
"""We have either consumed the entire response from the server or the response was set by an addon."""
|
||||
self.flow.response.timestamp_end = time.time()
|
||||
yield HttpResponseHook(self.flow)
|
||||
if (yield from self.check_killed(False)):
|
||||
return
|
||||
has_content = bool(self.flow.response.raw_content)
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
|
||||
if has_content:
|
||||
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
|
||||
|
||||
if not already_streamed:
|
||||
has_content = bool(self.flow.response.raw_content)
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
|
||||
if has_content:
|
||||
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
|
||||
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
|
||||
if self.flow.response.status_code == 101:
|
||||
if self.flow.response.headers.get("upgrade", "").strip().lower() == "websocket":
|
||||
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
|
||||
else:
|
||||
self.child_layer = tcp.TCPLayer(self.context)
|
||||
if self.debug:
|
||||
yield commands.Log(f"{self.debug}[http] upgrading to {self.child_layer}", "debug")
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
self._handle_event = self.passthrough
|
||||
return
|
||||
|
||||
self.server_state = self.state_done
|
||||
|
||||
def check_killed(self, emit_error_hook: bool) -> layer.CommandGenerator[bool]:
|
||||
killed_by_us = (
|
||||
self.flow.error and self.flow.error.msg == flow.Error.KILLED_MESSAGE
|
||||
)
|
||||
killed_by_remote = (
|
||||
self.context.client.state is not ConnectionState.OPEN
|
||||
self.context.client.state is not ConnectionState.OPEN
|
||||
)
|
||||
if killed_by_remote:
|
||||
if not self.flow.error:
|
||||
@ -320,15 +333,15 @@ class HttpStream(layer.Layer):
|
||||
event: typing.Union[RequestProtocolError, ResponseProtocolError]
|
||||
) -> layer.CommandGenerator[None]:
|
||||
is_client_error_but_we_already_talk_upstream = (
|
||||
isinstance(event, RequestProtocolError) and
|
||||
self.client_state in (self.state_stream_request_body, self.state_done)
|
||||
isinstance(event, RequestProtocolError) and
|
||||
self.client_state in (self.state_stream_request_body, self.state_done)
|
||||
)
|
||||
if is_client_error_but_we_already_talk_upstream:
|
||||
yield SendHttp(event, self.context.server)
|
||||
self.client_state = self.state_errored
|
||||
|
||||
response_hook_already_triggered = (
|
||||
self.server_state in (self.state_done, self.state_errored)
|
||||
self.server_state in (self.state_done, self.state_errored)
|
||||
)
|
||||
if not response_hook_already_triggered:
|
||||
# We don't want to trigger both a response hook and an error hook,
|
||||
@ -398,34 +411,48 @@ class HttpStream(layer.Layer):
|
||||
|
||||
if 200 <= self.flow.response.status_code < 300:
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||
self.child_layer = self.child_layer or layer.NextLayer(self.context)
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
self._handle_event = self.passthrough
|
||||
else:
|
||||
yield from self.send_response()
|
||||
return (yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client))
|
||||
|
||||
@expect(RequestData, RequestEndOfMessage, events.Event)
|
||||
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
# HTTP events -> normal connection events
|
||||
if isinstance(event, RequestData):
|
||||
event = events.DataReceived(self.context.client, event.data)
|
||||
elif isinstance(event, ResponseData):
|
||||
event = events.DataReceived(self.context.server, event.data)
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
event = events.ConnectionClosed(self.context.client)
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
event = events.ConnectionClosed(self.context.server)
|
||||
|
||||
for command in self.child_layer.handle_event(event):
|
||||
# normal connection events -> HTTP events
|
||||
if isinstance(command, commands.SendData) and command.connection == self.context.client:
|
||||
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
|
||||
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
|
||||
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client)
|
||||
if isinstance(command, commands.SendData):
|
||||
if command.connection == self.context.client:
|
||||
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
|
||||
elif command.connection == self.context.server and self.flow.response.status_code == 101:
|
||||
# there only is a HTTP server connection if we have switched protocols,
|
||||
# not if a connection is established via CONNECT.
|
||||
yield SendHttp(RequestData(self.stream_id, command.data), self.context.server)
|
||||
else:
|
||||
yield command
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
# If we are running TCP over HTTP we want to be consistent with half-closes.
|
||||
# The easiest approach for this is to just always full close for now.
|
||||
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,
|
||||
# but that is more complex to implement.
|
||||
command.half_close = False
|
||||
yield command
|
||||
if command.connection == self.context.client:
|
||||
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client)
|
||||
elif command.connection == self.context.server and self.flow.response.status_code == 101:
|
||||
yield SendHttp(RequestProtocolError(self.stream_id, "EOF"), self.context.server)
|
||||
else:
|
||||
# If we are running TCP over HTTP we want to be consistent with half-closes.
|
||||
# The easiest approach for this is to just always full close for now.
|
||||
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,
|
||||
# but that is more complex to implement.
|
||||
command.half_close = False
|
||||
yield command
|
||||
else:
|
||||
yield command
|
||||
|
||||
@ -540,8 +567,8 @@ class HttpLayer(layer.Layer):
|
||||
for connection in self.connections:
|
||||
# see "tricky multiplexing edge case" in make_http_connection for an explanation
|
||||
conn_is_pending_or_h2 = (
|
||||
connection.alpn == b"h2"
|
||||
or connection in self.waiting_for_establishment
|
||||
connection.alpn == b"h2"
|
||||
or connection in self.waiting_for_establishment
|
||||
)
|
||||
h2_to_h1 = self.context.client.alpn == b"h2" and not conn_is_pending_or_h2
|
||||
connection_suitable = (
|
||||
|
@ -401,7 +401,8 @@ def test_server_unreachable(tctx, connect):
|
||||
playbook << http.HttpErrorHook(flow)
|
||||
playbook >> reply()
|
||||
playbook << SendData(tctx.client, err)
|
||||
playbook << CloseConnection(tctx.client)
|
||||
if not connect:
|
||||
playbook << CloseConnection(tctx.client)
|
||||
|
||||
assert playbook
|
||||
if not connect:
|
||||
|
Loading…
Reference in New Issue
Block a user