Merge pull request #4283 from Kriechi/decouple++

websocket: decouple from pathod
This commit is contained in:
Thomas Kriechbaumer 2020-11-07 17:15:04 +01:00 committed by GitHub
commit 488be14412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 164 additions and 65 deletions

View File

@ -7,12 +7,70 @@ Spec: https://tools.ietf.org/html/rfc6455
import base64
import hashlib
import os
import struct
from wsproto.utilities import ACCEPT_GUID
from wsproto.handshake import WEBSOCKET_VERSION
from wsproto.frame_protocol import RsvBits, Header, Frame, XorMaskerSimple, XorMaskerNull
from mitmproxy.net import http
from mitmproxy.utils import strutils
from mitmproxy.utils import bits, strutils
def read_raw_frame(rfile):
consumed_bytes = b''
def consume(len):
nonlocal consumed_bytes
d = rfile.safe_read(len)
consumed_bytes += d
return d
first_byte, second_byte = consume(2)
fin = bits.getbit(first_byte, 7)
rsv1 = bits.getbit(first_byte, 6)
rsv2 = bits.getbit(first_byte, 5)
rsv3 = bits.getbit(first_byte, 4)
opcode = first_byte & 0xF
mask_bit = bits.getbit(second_byte, 7)
length_code = second_byte & 0x7F
# payload_len > 125 indicates you need to read more bytes
# to get the actual payload length
if length_code <= 125:
payload_len = length_code
elif length_code == 126:
payload_len, = struct.unpack("!H", consume(2))
else: # length_code == 127:
payload_len, = struct.unpack("!Q", consume(8))
# masking key only present if mask bit set
if mask_bit == 1:
masking_key = consume(4)
masker = XorMaskerSimple(masking_key)
else:
masking_key = None
masker = XorMaskerNull()
header = Header(
fin=fin,
rsv=RsvBits(rsv1, rsv2, rsv3),
opcode=opcode,
payload_len=payload_len,
masking_key=masking_key,
)
masked_payload = consume(payload_len)
payload = masker.process(masked_payload)
frame = Frame(
opcode=opcode,
payload=payload,
frame_finished=fin,
message_finished=fin
)
return header, frame, consumed_bytes
def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):

View File

@ -8,15 +8,12 @@ from wsproto.connection import ConnectionType
from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request
from wsproto.extensions import PerMessageDeflate
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy import exceptions, flow
from mitmproxy.proxy.protocol import base
from mitmproxy.net import tcp
from mitmproxy.net import tcp, websocket_utils
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
from mitmproxy.utils import strutils
from pathod.language import websockets_frame
class WebSocketLayer(base.Layer):
"""
@ -79,6 +76,10 @@ class WebSocketLayer(base.Layer):
assert isinstance(next(self.connections[self.server_conn].events()), events.AcceptConnection)
def _handle_event(self, event, source_conn, other_conn, is_server):
self.log(
"WebSocket Event from {}: {}".format("server" if is_server else "client", event),
"debug"
)
if isinstance(event, events.Message):
return self._handle_message(event, source_conn, other_conn, is_server)
elif isinstance(event, events.Ping):
@ -199,9 +200,17 @@ class WebSocketLayer(base.Layer):
other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
is_server = (source_conn == self.server_conn)
# TODO: replace this method from pathod with a stack-agnostic version
frame = websockets_frame.Frame.from_file(source_conn.rfile)
data = self.connections[source_conn].receive_data(bytes(frame))
header, frame, consumed_bytes = websocket_utils.read_raw_frame(source_conn.rfile)
self.log(
"WebSocket Frame from {}: {}, {}".format(
"server" if is_server else "client",
header,
frame,
),
"debug"
)
data = self.connections[source_conn].receive_data(consumed_bytes)
source_conn.send(data)
if close_received:

View File

@ -1,9 +1,42 @@
import pytest
from io import BytesIO
from unittest import mock
from wsproto.frame_protocol import Opcode, RsvBits, Header, Frame
from mitmproxy.net.http import Headers
from mitmproxy.net import websocket_utils
@pytest.mark.parametrize("input,masking_key,payload_length", [
(b'\x01\rserver-foobar', None, 13),
(b'\x01\x8dasdf\x12\x16\x16\x10\x04\x01I\x00\x0e\x1c\x06\x07\x13', b'asdf', 13),
(b'\x01~\x04\x00server-foobar', None, 1024),
(b'\x01\x7f\x00\x00\x00\x00\x00\x02\x00\x00server-foobar', None, 131072),
])
def test_read_raw_frame(input, masking_key, payload_length):
bio = BytesIO(input)
bio.safe_read = bio.read
header, frame, consumed_bytes = websocket_utils.read_raw_frame(bio)
assert header == \
Header(
fin=False,
rsv=RsvBits(rsv1=False, rsv2=False, rsv3=False),
opcode=Opcode.TEXT,
payload_len=payload_length,
masking_key=masking_key,
)
assert frame == \
Frame(
opcode=Opcode.TEXT,
payload=b'server-foobar',
frame_finished=False,
message_finished=False,
)
assert consumed_bytes == input
@mock.patch('os.urandom', return_value=b'pumpkinspumpkins')
def test_client_handshake_headers(_):
assert websocket_utils.client_handshake_headers() == \

View File

@ -146,12 +146,12 @@ class TestSimple(_WebSocketTest):
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=b'server-foobar')))
wfile.flush()
frame = websockets_frame.Frame.from_file(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
header, frame, _ = websocket_utils.read_raw_frame(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload)))
wfile.flush()
frame = websockets_frame.Frame.from_file(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
header, frame, _ = websocket_utils.read_raw_frame(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload)))
wfile.flush()
@pytest.mark.parametrize('streaming', [True, False])
@ -163,19 +163,19 @@ class TestSimple(_WebSocketTest):
self.proxy.set_addons(Stream())
self.setup_connection()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'server-foobar'
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'self.client-foobar'
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'\xde\xad\xbe\xef'
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE)))
@ -204,19 +204,19 @@ class TestSimple(_WebSocketTest):
self.proxy.set_addons(Addon())
self.setup_connection()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'foo'
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'foo'
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'foo'
@ -236,7 +236,7 @@ class TestKillFlow(_WebSocketTest):
self.setup_connection()
with pytest.raises(exceptions.TcpDisconnect):
websockets_frame.Frame.from_file(self.client.rfile)
_, _, _ = websocket_utils.read_raw_frame(self.client.rfile)
class TestSimpleTLS(_WebSocketTest):
@ -247,20 +247,20 @@ class TestSimpleTLS(_WebSocketTest):
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=b'server-foobar')))
wfile.flush()
frame = websockets_frame.Frame.from_file(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
header, frame, _ = websocket_utils.read_raw_frame(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload)))
wfile.flush()
def test_simple_tls(self):
self.setup_connection()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'server-foobar'
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame.payload == b'self.client-foobar'
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE)))
@ -274,8 +274,8 @@ class TestPing(_WebSocketTest):
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.PING, payload=b'foobar')))
wfile.flush()
frame = websockets_frame.Frame.from_file(rfile)
assert frame.header.opcode == Opcode.PONG
header, frame, _ = websocket_utils.read_raw_frame(rfile)
assert header.opcode == Opcode.PONG
assert frame.payload == b'foobar'
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.PONG, payload=b'done')))
@ -283,17 +283,17 @@ class TestPing(_WebSocketTest):
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE)))
wfile.flush()
websockets_frame.Frame.from_file(rfile)
_, _, _ = websocket_utils.read_raw_frame(rfile)
@pytest.mark.asyncio
async def test_ping(self):
self.setup_connection()
frame = websockets_frame.Frame.from_file(self.client.rfile)
websockets_frame.Frame.from_file(self.client.rfile)
header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE)))
self.client.wfile.flush()
assert frame.header.opcode == Opcode.PING
assert header.opcode == Opcode.PING
assert frame.payload == b'' # We don't send payload to other end
assert await self.master.await_log("Pong Received from server", "info")
@ -303,8 +303,8 @@ class TestPong(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets_frame.Frame.from_file(rfile)
assert frame.header.opcode == Opcode.PING
header, frame, _ = websocket_utils.read_raw_frame(rfile)
assert header.opcode == Opcode.PING
assert frame.payload == b''
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.PONG, payload=frame.payload)))
@ -312,7 +312,7 @@ class TestPong(_WebSocketTest):
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE)))
wfile.flush()
websockets_frame.Frame.from_file(rfile)
_ = websocket_utils.read_raw_frame(rfile)
@pytest.mark.asyncio
async def test_pong(self):
@ -321,12 +321,12 @@ class TestPong(_WebSocketTest):
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.PING, payload=b'foobar')))
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
websockets_frame.Frame.from_file(self.client.rfile)
header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE)))
self.client.wfile.flush()
assert frame.header.opcode == Opcode.PONG
assert header.opcode == Opcode.PONG
assert frame.payload == b'foobar'
assert await self.master.await_log("pong received")
@ -335,13 +335,13 @@ class TestClose(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets_frame.Frame.from_file(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
header, frame, _ = websocket_utils.read_raw_frame(rfile)
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload)))
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE)))
wfile.flush()
with pytest.raises(exceptions.TcpDisconnect):
websockets_frame.Frame.from_file(rfile)
_, _, _ = websocket_utils.read_raw_frame(rfile)
def test_close(self):
self.setup_connection()
@ -349,9 +349,9 @@ class TestClose(_WebSocketTest):
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE)))
self.client.wfile.flush()
websockets_frame.Frame.from_file(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
with pytest.raises(exceptions.TcpDisconnect):
websockets_frame.Frame.from_file(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
def test_close_payload_1(self):
self.setup_connection()
@ -359,9 +359,9 @@ class TestClose(_WebSocketTest):
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE, payload=b'\00\42')))
self.client.wfile.flush()
websockets_frame.Frame.from_file(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
with pytest.raises(exceptions.TcpDisconnect):
websockets_frame.Frame.from_file(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
def test_close_payload_2(self):
self.setup_connection()
@ -369,9 +369,9 @@ class TestClose(_WebSocketTest):
self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE, payload=b'\00\42foobar')))
self.client.wfile.flush()
websockets_frame.Frame.from_file(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
with pytest.raises(exceptions.TcpDisconnect):
websockets_frame.Frame.from_file(self.client.rfile)
_ = websocket_utils.read_raw_frame(self.client.rfile)
class TestInvalidFrame(_WebSocketTest):
@ -384,8 +384,7 @@ class TestInvalidFrame(_WebSocketTest):
def test_invalid_frame(self):
self.setup_connection()
# with pytest.raises(exceptions.TcpDisconnect):
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
code, = struct.unpack('!H', frame.payload[:2])
assert code == 1002
assert frame.payload[2:].startswith(b'Invalid opcode')
@ -410,11 +409,11 @@ class TestStreaming(_WebSocketTest):
frame = None
if not streaming:
with pytest.raises(exceptions.TcpDisconnect): # Reader.safe_read get nothing as result
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame is None
else:
frame = websockets_frame.Frame.from_file(self.client.rfile)
_, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert frame
assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received
@ -427,33 +426,33 @@ class TestExtension(_WebSocketTest):
wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00')
wfile.flush()
frame = websockets_frame.Frame.from_file(rfile)
assert frame.header.rsv1
header, _, _ = websocket_utils.read_raw_frame(rfile)
assert header.rsv.rsv1
wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00')
wfile.flush()
frame = websockets_frame.Frame.from_file(rfile)
assert frame.header.rsv1
header, _, _ = websocket_utils.read_raw_frame(rfile)
assert header.rsv.rsv1
wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00')
wfile.flush()
def test_extension(self):
self.setup_connection(True)
frame = websockets_frame.Frame.from_file(self.client.rfile)
assert frame.header.rsv1
header, _, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert header.rsv.rsv1
self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v')
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
assert frame.header.rsv1
header, _, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert header.rsv.rsv1
self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c')
self.client.wfile.flush()
frame = websockets_frame.Frame.from_file(self.client.rfile)
assert frame.header.rsv1
header, _, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert header.rsv.rsv1
assert len(self.master.state.flows[1].messages) == 5
assert self.master.state.flows[1].messages[0].content == 'server-foobar'
@ -482,8 +481,8 @@ class TestInjectMessageClient(_WebSocketTest):
self.proxy.set_addons(Inject())
self.setup_connection()
frame = websockets_frame.Frame.from_file(self.client.rfile)
assert frame.header.opcode == Opcode.TEXT
header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert header.opcode == Opcode.TEXT
assert frame.payload == b'This is an injected message!'
@ -491,8 +490,8 @@ class TestInjectMessageServer(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets_frame.Frame.from_file(rfile)
assert frame.header.opcode == Opcode.TEXT
header, frame, _ = websocket_utils.read_raw_frame(rfile)
assert header.opcode == Opcode.TEXT
success = frame.payload == b'This is an injected message!'
wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=str(success).encode())))
@ -506,6 +505,6 @@ class TestInjectMessageServer(_WebSocketTest):
self.proxy.set_addons(Inject())
self.setup_connection()
frame = websockets_frame.Frame.from_file(self.client.rfile)
assert frame.header.opcode == Opcode.TEXT
header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile)
assert header.opcode == Opcode.TEXT
assert frame.payload == b'True'