mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
[sans-io] handle 101 switching protocols
This commit is contained in:
parent
396673b2b1
commit
a4a0428bc6
@ -2,10 +2,12 @@ from . import modes
|
|||||||
from .http import HttpLayer
|
from .http import HttpLayer
|
||||||
from .tcp import TCPLayer
|
from .tcp import TCPLayer
|
||||||
from .tls import ClientTLSLayer, ServerTLSLayer
|
from .tls import ClientTLSLayer, ServerTLSLayer
|
||||||
|
from .websocket import WebsocketLayer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"modes",
|
"modes",
|
||||||
"HttpLayer",
|
"HttpLayer",
|
||||||
"TCPLayer",
|
"TCPLayer",
|
||||||
"ClientTLSLayer", "ServerTLSLayer",
|
"ClientTLSLayer", "ServerTLSLayer",
|
||||||
|
"WebsocketLayer",
|
||||||
]
|
]
|
||||||
|
@ -9,7 +9,7 @@ from mitmproxy.net.http import url
|
|||||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
from mitmproxy.proxy2 import commands, events, layer, tunnel
|
from mitmproxy.proxy2 import commands, events, layer, tunnel
|
||||||
from mitmproxy.proxy2.context import Connection, ConnectionState, Context, Server
|
from mitmproxy.proxy2.context import Connection, ConnectionState, Context, Server
|
||||||
from mitmproxy.proxy2.layers import tls
|
from mitmproxy.proxy2.layers import tls, websocket, tcp
|
||||||
from mitmproxy.proxy2.layers.http import _upstream_proxy
|
from mitmproxy.proxy2.layers.http import _upstream_proxy
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
@ -226,12 +226,14 @@ class HttpStream(layer.Layer):
|
|||||||
self.flow.request.timestamp_end = time.time()
|
self.flow.request.timestamp_end = time.time()
|
||||||
self.flow.request.data.content = self.request_body_buf
|
self.flow.request.data.content = self.request_body_buf
|
||||||
self.request_body_buf = b""
|
self.request_body_buf = b""
|
||||||
|
self.client_state = self.state_done
|
||||||
yield HttpRequestHook(self.flow)
|
yield HttpRequestHook(self.flow)
|
||||||
if (yield from self.check_killed(True)):
|
if (yield from self.check_killed(True)):
|
||||||
return
|
return
|
||||||
elif self.flow.response:
|
elif self.flow.response:
|
||||||
# response was set by an inline script.
|
# response was set by an inline script.
|
||||||
# we now need to emulate the responseheaders hook.
|
# we now need to emulate the responseheaders hook.
|
||||||
|
self.flow.response.timestamp_start = time.time()
|
||||||
yield HttpResponseHeadersHook(self.flow)
|
yield HttpResponseHeadersHook(self.flow)
|
||||||
if (yield from self.check_killed(True)):
|
if (yield from self.check_killed(True)):
|
||||||
return
|
return
|
||||||
@ -247,8 +249,6 @@ class HttpStream(layer.Layer):
|
|||||||
yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server)
|
yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server)
|
||||||
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
|
||||||
|
|
||||||
self.client_state = self.state_done
|
|
||||||
|
|
||||||
@expect(ResponseHeaders)
|
@expect(ResponseHeaders)
|
||||||
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
|
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
|
||||||
self.flow.response = event.response
|
self.flow.response = event.response
|
||||||
@ -270,32 +270,45 @@ class HttpStream(layer.Layer):
|
|||||||
data = event.data
|
data = event.data
|
||||||
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
|
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
|
||||||
elif isinstance(event, ResponseEndOfMessage):
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
self.flow.response.timestamp_end = time.time()
|
yield from self.send_response(already_streamed=True)
|
||||||
yield HttpResponseHook(self.flow)
|
|
||||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
|
||||||
self.server_state = self.state_done
|
|
||||||
|
|
||||||
@expect(ResponseData, ResponseEndOfMessage)
|
@expect(ResponseData, ResponseEndOfMessage)
|
||||||
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, ResponseData):
|
if isinstance(event, ResponseData):
|
||||||
self.response_body_buf += event.data
|
self.response_body_buf += event.data
|
||||||
elif isinstance(event, ResponseEndOfMessage):
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
self.flow.response.timestamp_end = time.time()
|
|
||||||
self.flow.response.data.content = self.response_body_buf
|
self.flow.response.data.content = self.response_body_buf
|
||||||
self.response_body_buf = b""
|
self.response_body_buf = b""
|
||||||
yield from self.send_response()
|
yield from self.send_response()
|
||||||
self.server_state = self.state_done
|
|
||||||
|
|
||||||
def send_response(self):
|
def send_response(self, already_streamed: bool = False):
|
||||||
|
"""We have either consumed the entire response from the server or the response was set by an addon."""
|
||||||
|
self.flow.response.timestamp_end = time.time()
|
||||||
yield HttpResponseHook(self.flow)
|
yield HttpResponseHook(self.flow)
|
||||||
if (yield from self.check_killed(False)):
|
if (yield from self.check_killed(False)):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not already_streamed:
|
||||||
has_content = bool(self.flow.response.raw_content)
|
has_content = bool(self.flow.response.raw_content)
|
||||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
|
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
|
||||||
if has_content:
|
if has_content:
|
||||||
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
|
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
|
||||||
|
|
||||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||||
|
|
||||||
|
if self.flow.response.status_code == 101:
|
||||||
|
if self.flow.response.headers.get("upgrade", "").strip().lower() == "websocket":
|
||||||
|
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
|
||||||
|
else:
|
||||||
|
self.child_layer = tcp.TCPLayer(self.context)
|
||||||
|
if self.debug:
|
||||||
|
yield commands.Log(f"{self.debug}[http] upgrading to {self.child_layer}", "debug")
|
||||||
|
yield from self.child_layer.handle_event(events.Start())
|
||||||
|
self._handle_event = self.passthrough
|
||||||
|
return
|
||||||
|
|
||||||
|
self.server_state = self.state_done
|
||||||
|
|
||||||
def check_killed(self, emit_error_hook: bool) -> layer.CommandGenerator[bool]:
|
def check_killed(self, emit_error_hook: bool) -> layer.CommandGenerator[bool]:
|
||||||
killed_by_us = (
|
killed_by_us = (
|
||||||
self.flow.error and self.flow.error.msg == flow.Error.KILLED_MESSAGE
|
self.flow.error and self.flow.error.msg == flow.Error.KILLED_MESSAGE
|
||||||
@ -398,28 +411,42 @@ class HttpStream(layer.Layer):
|
|||||||
|
|
||||||
if 200 <= self.flow.response.status_code < 300:
|
if 200 <= self.flow.response.status_code < 300:
|
||||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||||
|
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||||
self.child_layer = self.child_layer or layer.NextLayer(self.context)
|
self.child_layer = self.child_layer or layer.NextLayer(self.context)
|
||||||
yield from self.child_layer.handle_event(events.Start())
|
yield from self.child_layer.handle_event(events.Start())
|
||||||
self._handle_event = self.passthrough
|
self._handle_event = self.passthrough
|
||||||
else:
|
else:
|
||||||
yield from self.send_response()
|
yield from self.send_response()
|
||||||
return (yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client))
|
|
||||||
|
|
||||||
@expect(RequestData, RequestEndOfMessage, events.Event)
|
@expect(RequestData, RequestEndOfMessage, events.Event)
|
||||||
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
# HTTP events -> normal connection events
|
# HTTP events -> normal connection events
|
||||||
if isinstance(event, RequestData):
|
if isinstance(event, RequestData):
|
||||||
event = events.DataReceived(self.context.client, event.data)
|
event = events.DataReceived(self.context.client, event.data)
|
||||||
|
elif isinstance(event, ResponseData):
|
||||||
|
event = events.DataReceived(self.context.server, event.data)
|
||||||
elif isinstance(event, RequestEndOfMessage):
|
elif isinstance(event, RequestEndOfMessage):
|
||||||
event = events.ConnectionClosed(self.context.client)
|
event = events.ConnectionClosed(self.context.client)
|
||||||
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
|
event = events.ConnectionClosed(self.context.server)
|
||||||
|
|
||||||
for command in self.child_layer.handle_event(event):
|
for command in self.child_layer.handle_event(event):
|
||||||
# normal connection events -> HTTP events
|
# normal connection events -> HTTP events
|
||||||
if isinstance(command, commands.SendData) and command.connection == self.context.client:
|
if isinstance(command, commands.SendData):
|
||||||
|
if command.connection == self.context.client:
|
||||||
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
|
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
|
||||||
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
|
elif command.connection == self.context.server and self.flow.response.status_code == 101:
|
||||||
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client)
|
# there only is a HTTP server connection if we have switched protocols,
|
||||||
|
# not if a connection is established via CONNECT.
|
||||||
|
yield SendHttp(RequestData(self.stream_id, command.data), self.context.server)
|
||||||
|
else:
|
||||||
|
yield command
|
||||||
elif isinstance(command, commands.CloseConnection):
|
elif isinstance(command, commands.CloseConnection):
|
||||||
|
if command.connection == self.context.client:
|
||||||
|
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client)
|
||||||
|
elif command.connection == self.context.server and self.flow.response.status_code == 101:
|
||||||
|
yield SendHttp(RequestProtocolError(self.stream_id, "EOF"), self.context.server)
|
||||||
|
else:
|
||||||
# If we are running TCP over HTTP we want to be consistent with half-closes.
|
# If we are running TCP over HTTP we want to be consistent with half-closes.
|
||||||
# The easiest approach for this is to just always full close for now.
|
# The easiest approach for this is to just always full close for now.
|
||||||
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,
|
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,
|
||||||
|
@ -401,6 +401,7 @@ def test_server_unreachable(tctx, connect):
|
|||||||
playbook << http.HttpErrorHook(flow)
|
playbook << http.HttpErrorHook(flow)
|
||||||
playbook >> reply()
|
playbook >> reply()
|
||||||
playbook << SendData(tctx.client, err)
|
playbook << SendData(tctx.client, err)
|
||||||
|
if not connect:
|
||||||
playbook << CloseConnection(tctx.client)
|
playbook << CloseConnection(tctx.client)
|
||||||
|
|
||||||
assert playbook
|
assert playbook
|
||||||
|
Loading…
Reference in New Issue
Block a user