[sans-io] handle 101 switching protocols

This commit is contained in:
Maximilian Hils 2020-12-07 22:54:48 +01:00
parent 396673b2b1
commit a4a0428bc6
3 changed files with 62 additions and 32 deletions

View File

@ -2,10 +2,12 @@ from . import modes
from .http import HttpLayer
from .tcp import TCPLayer
from .tls import ClientTLSLayer, ServerTLSLayer
from .websocket import WebsocketLayer
__all__ = [
"modes",
"HttpLayer",
"TCPLayer",
"ClientTLSLayer", "ServerTLSLayer",
"WebsocketLayer",
]

View File

@ -9,7 +9,7 @@ from mitmproxy.net.http import url
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer, tunnel
from mitmproxy.proxy2.context import Connection, ConnectionState, Context, Server
from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.layers import tls, websocket, tcp
from mitmproxy.proxy2.layers.http import _upstream_proxy
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
@ -226,12 +226,14 @@ class HttpStream(layer.Layer):
self.flow.request.timestamp_end = time.time()
self.flow.request.data.content = self.request_body_buf
self.request_body_buf = b""
self.client_state = self.state_done
yield HttpRequestHook(self.flow)
if (yield from self.check_killed(True)):
return
elif self.flow.response:
# response was set by an inline script.
# we now need to emulate the responseheaders hook.
self.flow.response.timestamp_start = time.time()
yield HttpResponseHeadersHook(self.flow)
if (yield from self.check_killed(True)):
return
@ -247,8 +249,6 @@ class HttpStream(layer.Layer):
yield SendHttp(RequestData(self.stream_id, self.flow.request.raw_content), self.context.server)
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
self.client_state = self.state_done
@expect(ResponseHeaders)
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
self.flow.response = event.response
@ -270,32 +270,45 @@ class HttpStream(layer.Layer):
data = event.data
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
elif isinstance(event, ResponseEndOfMessage):
self.flow.response.timestamp_end = time.time()
yield HttpResponseHook(self.flow)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.server_state = self.state_done
yield from self.send_response(already_streamed=True)
@expect(ResponseData, ResponseEndOfMessage)
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, ResponseData):
self.response_body_buf += event.data
elif isinstance(event, ResponseEndOfMessage):
self.flow.response.timestamp_end = time.time()
self.flow.response.data.content = self.response_body_buf
self.response_body_buf = b""
yield from self.send_response()
self.server_state = self.state_done
def send_response(self):
def send_response(self, already_streamed: bool = False):
"""We have either consumed the entire response from the server or the response was set by an addon."""
self.flow.response.timestamp_end = time.time()
yield HttpResponseHook(self.flow)
if (yield from self.check_killed(False)):
return
if not already_streamed:
has_content = bool(self.flow.response.raw_content)
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not has_content), self.context.client)
if has_content:
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
if self.flow.response.status_code == 101:
if self.flow.response.headers.get("upgrade", "").strip().lower() == "websocket":
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
else:
self.child_layer = tcp.TCPLayer(self.context)
if self.debug:
yield commands.Log(f"{self.debug}[http] upgrading to {self.child_layer}", "debug")
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
return
self.server_state = self.state_done
def check_killed(self, emit_error_hook: bool) -> layer.CommandGenerator[bool]:
killed_by_us = (
self.flow.error and self.flow.error.msg == flow.Error.KILLED_MESSAGE
@ -398,28 +411,42 @@ class HttpStream(layer.Layer):
if 200 <= self.flow.response.status_code < 300:
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
self.child_layer = self.child_layer or layer.NextLayer(self.context)
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.passthrough
else:
yield from self.send_response()
return (yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client))
@expect(RequestData, RequestEndOfMessage, events.Event)
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
# HTTP events -> normal connection events
if isinstance(event, RequestData):
event = events.DataReceived(self.context.client, event.data)
elif isinstance(event, ResponseData):
event = events.DataReceived(self.context.server, event.data)
elif isinstance(event, RequestEndOfMessage):
event = events.ConnectionClosed(self.context.client)
elif isinstance(event, ResponseEndOfMessage):
event = events.ConnectionClosed(self.context.server)
for command in self.child_layer.handle_event(event):
# normal connection events -> HTTP events
if isinstance(command, commands.SendData) and command.connection == self.context.client:
if isinstance(command, commands.SendData):
if command.connection == self.context.client:
yield SendHttp(ResponseData(self.stream_id, command.data), self.context.client)
elif isinstance(command, commands.CloseConnection) and command.connection == self.context.client:
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client)
elif command.connection == self.context.server and self.flow.response.status_code == 101:
# there only is a HTTP server connection if we have switched protocols,
# not if a connection is established via CONNECT.
yield SendHttp(RequestData(self.stream_id, command.data), self.context.server)
else:
yield command
elif isinstance(command, commands.CloseConnection):
if command.connection == self.context.client:
yield SendHttp(ResponseProtocolError(self.stream_id, "EOF"), self.context.client)
elif command.connection == self.context.server and self.flow.response.status_code == 101:
yield SendHttp(RequestProtocolError(self.stream_id, "EOF"), self.context.server)
else:
# If we are running TCP over HTTP we want to be consistent with half-closes.
# The easiest approach for this is to just always full close for now.
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,

View File

@ -401,6 +401,7 @@ def test_server_unreachable(tctx, connect):
playbook << http.HttpErrorHook(flow)
playbook >> reply()
playbook << SendData(tctx.client, err)
if not connect:
playbook << CloseConnection(tctx.client)
assert playbook