[sans-io] handle 101 switching protocols

This commit is contained in:
Maximilian Hils 2020-12-07 22:54:48 +01:00
parent 396673b2b1
commit a4a0428bc6
3 changed files with 62 additions and 32 deletions

View File

@ -2,10 +2,12 @@ from . import modes
from .http import HttpLayer from .http import HttpLayer
from .tcp import TCPLayer from .tcp import TCPLayer
from .tls import ClientTLSLayer, ServerTLSLayer from .tls import ClientTLSLayer, ServerTLSLayer
from .websocket import WebsocketLayer
__all__ = [ __all__ = [
"modes", "modes",
"HttpLayer", "HttpLayer",
"TCPLayer", "TCPLayer",
"ClientTLSLayer", "ServerTLSLayer", "ClientTLSLayer", "ServerTLSLayer",
"WebsocketLayer",
] ]

View File

@ -9,7 +9,7 @@ from mitmproxy.net.http import url
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer, tunnel from mitmproxy.proxy2 import commands, events, layer, tunnel
from mitmproxy.proxy2.context import Connection, ConnectionState, Context, Server from mitmproxy.proxy2.context import Connection, ConnectionState, Context, Server
from mitmproxy.proxy2.layers import tls from mitmproxy.proxy2.layers import tls, websocket, tcp
from mitmproxy.proxy2.layers.http import _upstream_proxy from mitmproxy.proxy2.layers.http import _upstream_proxy
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
@ -226,12 +226,14 @@ class HttpStream(layer.Layer):
self.flow.request.timestamp_end = time.time() self.flow.request.timestamp_end = time.time()
self.flow.request.data.content = self.request_body_buf self.flow.request.data.content = self.request_body_buf
self.request_body_buf = b"" self.request_body_buf = b""
self.client_state = self.state_done
yield HttpRequestHook(self.flow) yield HttpRequestHook(self.flow)
if (yield from self.check_killed(True)): if (yield from self.check_killed(True)):
return return
elif self.flow.response: elif self.flow.response:
# response was set by an inline script. # response was set by an inline script.
# we now need to emulate the responseheaders hook. # we now need to emulate the responseheaders hook.
self.flow.response.timestamp_start = time.time()
yield HttpResponseHeadersHook(self.flow) yield HttpResponseHeadersHook(self.flow)
if (yield from self.check_killed(True)): if (yield from self.check_killed(True)):
return return
@ -247,8 +249,6 @@ class HttpStream(layer.Layer):
yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server) yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server)
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server) yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
self.client_state = self.state_done
@expect(ResponseHeaders) @expect(ResponseHeaders)
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]: def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
self.flow.response = event.response self.flow.response = event.response
@ -270,32 +270,45 @@ class HttpStream(layer.Layer):
data = event.data data = event.data
yield SendHttp(ResponseData(self.stream_id, data), self.context.client) yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
self.flow.response.timestamp_end = time.time() yield from self.send_response(already_streamed=True)
yield HttpResponseHook(self.flow)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.server_state = self.state_done
@expect(ResponseData, ResponseEndOfMessage) @expect(ResponseData, ResponseEndOfMessage)
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]: def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, ResponseData): if isinstance(event, ResponseData):
self.response_body_buf += event.data self.response_body_buf += event.data
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
self.flow.response.timestamp_end = time.time()
self.flow.response.data.content = self.response_body_buf self.flow.response.data.content = self.response_body_buf
self.response_body_buf = b"" self.response_body_buf = b""
yield from self.send_response() yield from self.send_response()
self.server_state = self.state_done
def send_response(self): def send_response(self, already_streamed: bool = False):
"""We have either consumed the entire response from the server or the response was set by an addon."""
self.flow.response.timestamp_end = time.time()
yield HttpResponseHook(self.flow) yield HttpResponseHook(self.flow)
if (yield from self.check_killed(False)): if (yield from self.check_killed(False)):
return return
if not already_streamed:
has_content = bool(self.flow.response.raw_content) has_content = bool(self.flow.response.raw_content)
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client) yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
if has_content: if has_content:
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client) yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
if self.flow.response.status_code == 101:
if self.flow.response.headers.get("upgrade", "").strip().lower() == "websocket":
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
else:
self.child_layer = tcp.TCPLayer(self.context)
if self.debug:
yield commands.Log(f"{self.debug}[http] upgrading to {self.child_layer}", "debug")
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
return
self.server_state = self.state_done
def check_killed(self, emit_error_hook: bool) -> layer.CommandGenerator[bool]: def check_killed(self, emit_error_hook: bool) -> layer.CommandGenerator[bool]:
killed_by_us = ( killed_by_us = (
self.flow.error and self.flow.error.msg == flow.Error.KILLED_MESSAGE self.flow.error and self.flow.error.msg == flow.Error.KILLED_MESSAGE
@ -398,28 +411,42 @@ class HttpStream(layer.Layer):
if 200 <= self.flow.response.status_code < 300: if 200 <= self.flow.response.status_code < 300:
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client) yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.child_layer = self.child_layer or layer.NextLayer(self.context) self.child_layer = self.child_layer or layer.NextLayer(self.context)
yield from self.child_layer.handle_event(events.Start()) yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough self._handle_event = self.passthrough
else: else:
yield from self.send_response() yield from self.send_response()
return (yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client))
@expect(RequestData, RequestEndOfMessage, events.Event) @expect(RequestData, RequestEndOfMessage, events.Event)
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]: def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
# HTTP events -> normal connection events # HTTP events -> normal connection events
if isinstance(event, RequestData): if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data) event = events.DataReceived(self.context.client, event.data)
elif isinstance(event, ResponseData):
event = events.DataReceived(self.context.server, event.data)
elif isinstance(event, RequestEndOfMessage): elif isinstance(event, RequestEndOfMessage):
event = events.ConnectionClosed(self.context.client) event = events.ConnectionClosed(self.context.client)
elif isinstance(event, ResponseEndOfMessage):
event = events.ConnectionClosed(self.context.server)
for command in self.child_layer.handle_event(event): for command in self.child_layer.handle_event(event):
# normal connection events -> HTTP events # normal connection events -> HTTP events
if isinstance(command, commands.SendData) and command.connection == self.context.client: if isinstance(command, commands.SendData):
if command.connection == self.context.client:
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client) yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client: elif command.connection == self.context.server and self.flow.response.status_code == 101:
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client) # there only is a HTTP server connection if we have switched protocols,
# not if a connection is established via CONNECT.
yield SendHttp(RequestData(self.stream_id, command.data), self.context.server)
else:
yield command
elif isinstance(command, commands.CloseConnection): elif isinstance(command, commands.CloseConnection):
if command.connection == self.context.client:
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client)
elif command.connection == self.context.server and self.flow.response.status_code == 101:
yield SendHttp(RequestProtocolError(self.stream_id, "EOF"), self.context.server)
else:
# If we are running TCP over HTTP we want to be consistent with half-closes. # If we are running TCP over HTTP we want to be consistent with half-closes.
# The easiest approach for this is to just always full close for now. # The easiest approach for this is to just always full close for now.
# Alternatively, we could signal that we want a half close only through ResponseProtocolError, # Alternatively, we could signal that we want a half close only through ResponseProtocolError,

View File

@ -401,6 +401,7 @@ def test_server_unreachable(tctx, connect):
playbook << http.HttpErrorHook(flow) playbook << http.HttpErrorHook(flow)
playbook >> reply() playbook >> reply()
playbook << SendData(tctx.client, err) playbook << SendData(tctx.client, err)
if not connect:
playbook << CloseConnection(tctx.client) playbook << CloseConnection(tctx.client)
assert playbook assert playbook