mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
add sans-io websocket layer
This commit is contained in:
parent
6d59d213e3
commit
4c2fb7f250
@ -1,13 +1,16 @@
|
|||||||
import typing
|
import typing
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import h11
|
import h11
|
||||||
|
from mitmproxy.net import http
|
||||||
from mitmproxy.proxy.protocol2 import events, commands, websocket
|
from mitmproxy.proxy.protocol2 import events, commands, websocket
|
||||||
from mitmproxy.proxy.protocol2.context import ClientServerContext
|
from mitmproxy.proxy.protocol2.context import ClientServerContext
|
||||||
from mitmproxy.proxy.protocol2.layer import Layer
|
from mitmproxy.proxy.protocol2.layer import Layer
|
||||||
from mitmproxy.proxy.protocol2.utils import expect
|
from mitmproxy.proxy.protocol2.utils import expect
|
||||||
|
from mitmproxy.net import websockets
|
||||||
|
|
||||||
|
|
||||||
class HTTPLayer(Layer):
|
class HTTPLayer(Layer):
|
||||||
@ -22,7 +25,7 @@ class HTTPLayer(Layer):
|
|||||||
def __init__(self, context: ClientServerContext):
|
def __init__(self, context: ClientServerContext):
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.state = self.read_request_headers
|
self.state = self.read_request_headers
|
||||||
|
self.flow = mock.Mock("mitmproxy.http.HTTPFlow")
|
||||||
self.client_conn = h11.Connection(h11.SERVER)
|
self.client_conn = h11.Connection(h11.SERVER)
|
||||||
self.server_conn = h11.Connection(h11.CLIENT)
|
self.server_conn = h11.Connection(h11.CLIENT)
|
||||||
|
|
||||||
@ -63,6 +66,7 @@ class HTTPLayer(Layer):
|
|||||||
if self.client_conn.client_is_waiting_for_100_continue:
|
if self.client_conn.client_is_waiting_for_100_continue:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
self.flow.request.headers = http.Headers(event.headers)
|
||||||
self.flow_events[0].append(event)
|
self.flow_events[0].append(event)
|
||||||
self.state = self.read_request_body
|
self.state = self.read_request_body
|
||||||
yield from self.read_request_body() # there may already be further events.
|
yield from self.read_request_body() # there may already be further events.
|
||||||
@ -108,9 +112,9 @@ class HTTPLayer(Layer):
|
|||||||
self.state = self.read_response_body
|
self.state = self.read_response_body
|
||||||
yield from self.read_response_body() # there may already be further events.
|
yield from self.read_response_body() # there may already be further events.
|
||||||
elif isinstance(event, h11.InformationalResponse):
|
elif isinstance(event, h11.InformationalResponse):
|
||||||
if event.status_code == 101:
|
self.flow.response.headers = http.Headers(event.headers)
|
||||||
# FIXME: check if this is actually WebSocket
|
if event.status_code == 101 and websockets.check_handshake(self.flow.response.headers):
|
||||||
child_layer = websocket.WebsocketLayer(self.context, None)
|
child_layer = websocket.WebsocketLayer(self.context, self.flow)
|
||||||
yield from child_layer.handle_event(events.Start())
|
yield from child_layer.handle_event(events.Start())
|
||||||
self._handle_event = child_layer.handle_event
|
self._handle_event = child_layer.handle_event
|
||||||
return
|
return
|
||||||
|
146
mitmproxy/proxy/protocol2/test/test_websocket.py
Normal file
146
mitmproxy/proxy/protocol2/test/test_websocket.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from . import tutils
|
||||||
|
from .. import commands
|
||||||
|
from .. import events
|
||||||
|
from .. import websocket
|
||||||
|
from mitmproxy.test import tflow
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ws_playbook(tctx):
|
||||||
|
tctx.server.connected = True
|
||||||
|
playbook = tutils.playbook(
|
||||||
|
websocket.WebsocketLayer(
|
||||||
|
tctx,
|
||||||
|
tflow.twebsocketflow().handshake_flow
|
||||||
|
)
|
||||||
|
)
|
||||||
|
with mock.patch("os.urandom") as m:
|
||||||
|
m.return_value = b"\x10\x11\x12\x13"
|
||||||
|
yield playbook
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple(tctx, ws_playbook):
|
||||||
|
f = tutils.Placeholder()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
ws_playbook
|
||||||
|
<< commands.Hook("websocket_start", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
>> events.DataReceived(tctx.client, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f") # Frame with payload b"Hello"
|
||||||
|
<< commands.Hook("websocket_message", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.SendData(tctx.server, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f")
|
||||||
|
>> events.DataReceived(tctx.server, b'\x81\x05Hello') # Frame with payload "Hello"
|
||||||
|
<< commands.Hook("websocket_message", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.SendData(tctx.client, b'\x81\x05Hello')
|
||||||
|
>> events.DataReceived(tctx.client, b'\x88\x82\x10\x11\x12\x13\x13\xf9') # Closing frame
|
||||||
|
<< commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9')
|
||||||
|
<< commands.SendData(tctx.client, b'\x88\x02\x03\xe8')
|
||||||
|
<< commands.Hook("websocket_end", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
>> events.DataReceived(tctx.server, b'\x81\x05Hello')
|
||||||
|
<< None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(f().messages) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_close(tctx, ws_playbook):
|
||||||
|
f = tutils.Placeholder()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
ws_playbook
|
||||||
|
<< commands.Hook("websocket_start", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
>> events.DataReceived(tctx.server, b'\x88\x02\x03\xe8')
|
||||||
|
<< commands.SendData(tctx.client, b'\x88\x02\x03\xe8')
|
||||||
|
<< commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9')
|
||||||
|
<< commands.Hook("websocket_end", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.CloseConnection(tctx.client)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ping_pong(tctx, ws_playbook):
|
||||||
|
f = tutils.Placeholder()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
ws_playbook
|
||||||
|
<< commands.Hook("websocket_start", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
>> events.DataReceived(tctx.client, b'\x89\x80\x10\x11\x12\x13') # Ping
|
||||||
|
<< commands.Log("info", "Websocket PING received ")
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.SendData(tctx.server, b'\x89\x80\x10\x11\x12\x13')
|
||||||
|
<< commands.SendData(tctx.client, b'\x8a\x00')
|
||||||
|
>> events.DataReceived(tctx.server, b'\x8a\x00') # Pong
|
||||||
|
<< commands.Log("info", "Websocket PONG received ")
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_failed(tctx, ws_playbook):
|
||||||
|
f = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
ws_playbook
|
||||||
|
<< commands.Hook("websocket_start", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
>> events.DataReceived(tctx.client, b"Not a valid frame")
|
||||||
|
<< commands.SendData(tctx.server, b'\x88\x94\x10\x11\x12\x13\x13\xfb[}fp~zt1}cs~vv0!jv')
|
||||||
|
<< commands.SendData(tctx.client, b'\x88\x14\x03\xeaInvalid opcode 0xe')
|
||||||
|
<< commands.Hook("websocket_error", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.Hook("websocket_end", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension(tctx):
|
||||||
|
f = tutils.Placeholder()
|
||||||
|
tctx.server.connected = True
|
||||||
|
handshake_flow = tflow.twebsocketflow().handshake_flow
|
||||||
|
handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;"
|
||||||
|
handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;"
|
||||||
|
playbook = tutils.playbook(websocket.WebsocketLayer(tctx, handshake_flow))
|
||||||
|
with mock.patch("os.urandom") as m:
|
||||||
|
m.return_value = b"\x10\x11\x12\x13"
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< commands.Hook("websocket_start", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
>> events.DataReceived(tctx.client, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12') # Compressed Frame
|
||||||
|
<< commands.Hook("websocket_message", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.SendData(tctx.server, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12')
|
||||||
|
>> events.DataReceived(tctx.server, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00') # Compressed Frame
|
||||||
|
<< commands.Hook("websocket_message", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.SendData(tctx.client, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00')
|
||||||
|
)
|
||||||
|
assert len(f().messages) == 2
|
||||||
|
assert f().messages[0].content == "Hello"
|
||||||
|
assert f().messages[1].content == "Hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_closed(tctx, ws_playbook):
|
||||||
|
f = tutils.Placeholder()
|
||||||
|
assert (
|
||||||
|
ws_playbook
|
||||||
|
<< commands.Hook("websocket_start", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
>> events.ConnectionClosed(tctx.server)
|
||||||
|
<< commands.Log("error", "Connection closed abnormally")
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.CloseConnection(tctx.client)
|
||||||
|
<< commands.Hook("websocket_error", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
<< commands.Hook("websocket_end", f)
|
||||||
|
>> events.HookReply(-1, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert f().error
|
@ -1,8 +1,10 @@
|
|||||||
from mitmproxy import websocket, http
|
from mitmproxy import websocket, http, flow
|
||||||
from mitmproxy.proxy.protocol2 import events, commands
|
from mitmproxy.proxy.protocol2 import events, commands
|
||||||
from mitmproxy.proxy.protocol2.context import ClientServerContext
|
from mitmproxy.proxy.protocol2.context import ClientServerContext
|
||||||
from mitmproxy.proxy.protocol2.layer import Layer
|
from mitmproxy.proxy.protocol2.layer import Layer
|
||||||
from mitmproxy.proxy.protocol2.utils import expect
|
from mitmproxy.proxy.protocol2.utils import expect
|
||||||
|
from wsproto import connection as wsconn
|
||||||
|
from wsproto import events as wsevents
|
||||||
|
|
||||||
|
|
||||||
class WebsocketLayer(Layer):
|
class WebsocketLayer(Layer):
|
||||||
@ -16,14 +18,101 @@ class WebsocketLayer(Layer):
|
|||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow)
|
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow)
|
||||||
assert context.server.connected
|
assert context.server.connected
|
||||||
|
self.client_frame_buffer = []
|
||||||
|
self.server_frame_buffer = []
|
||||||
|
extension = self.flow.server_extensions
|
||||||
|
self.client_conn = wsconn.WSConnection(wsconn.SERVER, wsconn.ConnectionState.OPEN,
|
||||||
|
extensions=[extension] if extension else None)
|
||||||
|
|
||||||
|
self.server_conn = wsconn.WSConnection(wsconn.CLIENT, wsconn.ConnectionState.OPEN,
|
||||||
|
extensions=[extension] if extension else None)
|
||||||
|
|
||||||
@expect(events.Start)
|
@expect(events.Start)
|
||||||
def start(self, _) -> commands.TCommandGenerator:
|
def start(self, _) -> commands.TCommandGenerator:
|
||||||
yield from ()
|
yield commands.Hook("websocket_start", self.flow)
|
||||||
self._handle_event = self.relay_messages
|
self._handle_event = self.relay_messages
|
||||||
|
|
||||||
|
_handle_event = start
|
||||||
|
|
||||||
@expect(events.DataReceived, events.ConnectionClosed)
|
@expect(events.DataReceived, events.ConnectionClosed)
|
||||||
def relay_messages(self, event: events.Event) -> commands.TCommandGenerator:
|
def relay_messages(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
raise NotImplementedError()
|
if isinstance(event, events.DataReceived):
|
||||||
|
from_client = event.connection == self.context.client
|
||||||
|
if from_client:
|
||||||
|
source = self.client_conn
|
||||||
|
other = self.server_conn
|
||||||
|
fb = self.client_frame_buffer
|
||||||
|
send_to = self.context.server
|
||||||
|
else:
|
||||||
|
source = self.server_conn
|
||||||
|
other = self.client_conn
|
||||||
|
fb = self.server_frame_buffer
|
||||||
|
send_to = self.context.client
|
||||||
|
|
||||||
_handle_event = start
|
source.receive_bytes(event.data)
|
||||||
|
|
||||||
|
for ws_event in source.events():
|
||||||
|
if isinstance(ws_event, wsevents.DataReceived):
|
||||||
|
fb.append(ws_event.data)
|
||||||
|
if ws_event.message_finished:
|
||||||
|
if isinstance(ws_event, wsevents.BytesReceived):
|
||||||
|
payload = b"".join(fb)
|
||||||
|
else:
|
||||||
|
payload = "".join(fb)
|
||||||
|
fb.clear()
|
||||||
|
websocket_message = websocket.WebSocketMessage(0x1 if isinstance(ws_event, wsevents.TextReceived) else 0x2,
|
||||||
|
from_client, payload)
|
||||||
|
self.flow.messages.append(websocket_message)
|
||||||
|
yield commands.Hook("websocket_message", self.flow)
|
||||||
|
|
||||||
|
other.send_data(ws_event.data, ws_event.message_finished)
|
||||||
|
yield commands.SendData(send_to, other.bytes_to_send())
|
||||||
|
elif isinstance(ws_event, wsevents.PingReceived):
|
||||||
|
yield commands.Log(
|
||||||
|
"info",
|
||||||
|
"Websocket PING received {}".format(ws_event.payload.decode())
|
||||||
|
)
|
||||||
|
other.ping()
|
||||||
|
yield commands.SendData(send_to, other.bytes_to_send())
|
||||||
|
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
|
||||||
|
elif isinstance(ws_event, wsevents.PongReceived):
|
||||||
|
yield commands.Log(
|
||||||
|
"info",
|
||||||
|
"Websocket PONG received {}".format(ws_event.payload.decode())
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(ws_event, wsevents.ConnectionClosed):
|
||||||
|
other.close(ws_event.code, ws_event.reason)
|
||||||
|
yield commands.SendData(send_to, other.bytes_to_send())
|
||||||
|
# FIXME: Wait for other end to actually send the closing frame
|
||||||
|
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
|
||||||
|
|
||||||
|
if ws_event.code != 1000:
|
||||||
|
self.flow.error = flow.Error(
|
||||||
|
"WebSocket connection closed unexpectedly by {}: {}".format(
|
||||||
|
"client" if from_client else "server",
|
||||||
|
ws_event.reason
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield commands.Hook("websocket_error", self.flow)
|
||||||
|
|
||||||
|
yield commands.Hook("websocket_end", self.flow)
|
||||||
|
if not from_client:
|
||||||
|
yield commands.CloseConnection(self.context.client)
|
||||||
|
self._handle_event = self.done
|
||||||
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
|
yield commands.Log("error", "Connection closed abnormally")
|
||||||
|
self.flow.error = flow.Error(
|
||||||
|
"WebSocket connection closed unexpectedly by {}".format(
|
||||||
|
"client" if event.connection == self.context.client else "server"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if event.connection == self.context.server:
|
||||||
|
yield commands.CloseConnection(self.context.client)
|
||||||
|
yield commands.Hook("websocket_error", self.flow)
|
||||||
|
yield commands.Hook("websocket_end", self.flow)
|
||||||
|
self._handle_event = self.done
|
||||||
|
|
||||||
|
@expect(events.DataReceived, events.ConnectionClosed)
|
||||||
|
def done(self, _):
|
||||||
|
yield from ()
|
||||||
|
Loading…
Reference in New Issue
Block a user