mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 14:58:38 +00:00
refactor websockets into protocol
This commit is contained in:
parent
6dcfc35011
commit
bd5ee21284
2
netlib/websockets/__init__.py
Normal file
2
netlib/websockets/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from frame import *
|
||||
from protocol import *
|
@ -5,26 +5,14 @@ import os
|
||||
import struct
|
||||
import io
|
||||
|
||||
from . import utils, odict, tcp
|
||||
from .protocol import Masker
|
||||
from .. import utils, odict, tcp
|
||||
|
||||
# Colleciton of utility functions that implement small portions of the RFC6455
|
||||
# WebSockets Protocol Useful for building WebSocket clients and servers.
|
||||
#
|
||||
# Emphassis is on readabilty, simplicity and modularity, not performance or
|
||||
# completeness
|
||||
#
|
||||
# This is a work in progress and does not yet contain all the utilites need to
|
||||
# create fully complient client/servers #
|
||||
# Spec: https://tools.ietf.org/html/rfc6455
|
||||
DEFAULT = object()
|
||||
|
||||
# The magic sha that websocket servers must know to prove they understand
|
||||
# RFC6455
|
||||
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||
VERSION = "13"
|
||||
MAX_16_BIT_INT = (1 << 16)
|
||||
MAX_64_BIT_INT = (1 << 64)
|
||||
|
||||
|
||||
OPCODE = utils.BiDi(
|
||||
CONTINUE=0x00,
|
||||
TEXT=0x01,
|
||||
@ -34,101 +22,6 @@ OPCODE = utils.BiDi(
|
||||
PONG=0x0a
|
||||
)
|
||||
|
||||
|
||||
class Masker(object):
|
||||
|
||||
"""
|
||||
Data sent from the server must be masked to prevent malicious clients
|
||||
from sending data over the wire in predictable patterns
|
||||
|
||||
Servers do not have to mask data they send to the client.
|
||||
https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.masks = [utils.bytes_to_int(byte) for byte in key]
|
||||
self.offset = 0
|
||||
|
||||
def mask(self, offset, data):
|
||||
result = ""
|
||||
for c in data:
|
||||
result += chr(ord(c) ^ self.masks[offset % 4])
|
||||
offset += 1
|
||||
return result
|
||||
|
||||
def __call__(self, data):
|
||||
ret = self.mask(self.offset, data)
|
||||
self.offset += len(ret)
|
||||
return ret
|
||||
|
||||
|
||||
def client_handshake_headers(key=None, version=VERSION):
|
||||
"""
|
||||
Create the headers for a valid HTTP upgrade request. If Key is not
|
||||
specified, it is generated, and can be found in sec-websocket-key in
|
||||
the returned header set.
|
||||
|
||||
Returns an instance of ODictCaseless
|
||||
"""
|
||||
if not key:
|
||||
key = base64.b64encode(os.urandom(16)).decode('utf-8')
|
||||
return odict.ODictCaseless([
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
('Sec-WebSocket-Key', key),
|
||||
('Sec-WebSocket-Version', version)
|
||||
])
|
||||
|
||||
|
||||
def server_handshake_headers(key):
|
||||
"""
|
||||
The server response is a valid HTTP 101 response.
|
||||
"""
|
||||
return odict.ODictCaseless(
|
||||
[
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
('Sec-WebSocket-Accept', create_server_nonce(key))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def make_length_code(length):
|
||||
"""
|
||||
A websockets frame contains an initial length_code, and an optional
|
||||
extended length code to represent the actual length if length code is
|
||||
larger than 125
|
||||
"""
|
||||
if length <= 125:
|
||||
return length
|
||||
elif length >= 126 and length <= 65535:
|
||||
return 126
|
||||
else:
|
||||
return 127
|
||||
|
||||
|
||||
def check_client_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first('sec-websocket-key')
|
||||
|
||||
|
||||
def check_server_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first('sec-websocket-accept')
|
||||
|
||||
|
||||
def create_server_nonce(client_nonce):
|
||||
return base64.b64encode(
|
||||
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
|
||||
)
|
||||
|
||||
|
||||
DEFAULT = object()
|
||||
|
||||
|
||||
class FrameHeader(object):
|
||||
|
||||
def __init__(
|
||||
@ -153,7 +46,7 @@ class FrameHeader(object):
|
||||
self.rsv3 = rsv3
|
||||
|
||||
if length_code is DEFAULT:
|
||||
self.length_code = make_length_code(self.payload_length)
|
||||
self.length_code = self._make_length_code(self.payload_length)
|
||||
else:
|
||||
self.length_code = length_code
|
||||
|
||||
@ -173,6 +66,20 @@ class FrameHeader(object):
|
||||
if self.masking_key and len(self.masking_key) != 4:
|
||||
raise ValueError("Masking key must be 4 bytes.")
|
||||
|
||||
@classmethod
|
||||
def _make_length_code(self, length):
|
||||
"""
|
||||
A websockets frame contains an initial length_code, and an optional
|
||||
extended length code to represent the actual length if length code is
|
||||
larger than 125
|
||||
"""
|
||||
if length <= 125:
|
||||
return length
|
||||
elif length >= 126 and length <= 65535:
|
||||
return 126
|
||||
else:
|
||||
return 127
|
||||
|
||||
def human_readable(self):
|
||||
vals = [
|
||||
"ws frame:",
|
111
netlib/websockets/protocol.py
Normal file
111
netlib/websockets/protocol.py
Normal file
@ -0,0 +1,111 @@
|
||||
from __future__ import absolute_import
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import struct
|
||||
import io
|
||||
|
||||
from .. import utils, odict, tcp
|
||||
|
||||
# Colleciton of utility functions that implement small portions of the RFC6455
|
||||
# WebSockets Protocol Useful for building WebSocket clients and servers.
|
||||
#
|
||||
# Emphassis is on readabilty, simplicity and modularity, not performance or
|
||||
# completeness
|
||||
#
|
||||
# This is a work in progress and does not yet contain all the utilites need to
|
||||
# create fully complient client/servers #
|
||||
# Spec: https://tools.ietf.org/html/rfc6455
|
||||
|
||||
# The magic sha that websocket servers must know to prove they understand
|
||||
# RFC6455
|
||||
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||
VERSION = "13"
|
||||
|
||||
HEADER_WEBSOCKET_KEY = 'sec-websocket-key'
|
||||
HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept'
|
||||
HEADER_WEBSOCKET_VERSION = 'sec-websocket-version'
|
||||
|
||||
class Masker(object):
|
||||
|
||||
"""
|
||||
Data sent from the server must be masked to prevent malicious clients
|
||||
from sending data over the wire in predictable patterns
|
||||
|
||||
Servers do not have to mask data they send to the client.
|
||||
https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.masks = [utils.bytes_to_int(byte) for byte in key]
|
||||
self.offset = 0
|
||||
|
||||
def mask(self, offset, data):
|
||||
result = ""
|
||||
for c in data:
|
||||
result += chr(ord(c) ^ self.masks[offset % 4])
|
||||
offset += 1
|
||||
return result
|
||||
|
||||
def __call__(self, data):
|
||||
ret = self.mask(self.offset, data)
|
||||
self.offset += len(ret)
|
||||
return ret
|
||||
|
||||
class WebsocketsProtocol(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def client_handshake_headers(self, key=None, version=VERSION):
|
||||
"""
|
||||
Create the headers for a valid HTTP upgrade request. If Key is not
|
||||
specified, it is generated, and can be found in sec-websocket-key in
|
||||
the returned header set.
|
||||
|
||||
Returns an instance of ODictCaseless
|
||||
"""
|
||||
if not key:
|
||||
key = base64.b64encode(os.urandom(16)).decode('utf-8')
|
||||
return odict.ODictCaseless([
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
(HEADER_WEBSOCKET_KEY, key),
|
||||
(HEADER_WEBSOCKET_VERSION, version)
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def server_handshake_headers(self, key):
|
||||
"""
|
||||
The server response is a valid HTTP 101 response.
|
||||
"""
|
||||
return odict.ODictCaseless(
|
||||
[
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
(HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def check_client_handshake(self, headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first(HEADER_WEBSOCKET_KEY)
|
||||
|
||||
|
||||
@classmethod
|
||||
def check_server_handshake(self, headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first(HEADER_WEBSOCKET_ACCEPT)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_server_nonce(self, client_nonce):
|
||||
return base64.b64encode(
|
||||
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
|
||||
)
|
@ -12,6 +12,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
super(WebSocketsEchoHandler, self).__init__(
|
||||
connection, address, server
|
||||
)
|
||||
self.protocol = websockets.WebsocketsProtocol()
|
||||
self.handshake_done = False
|
||||
|
||||
def handle(self):
|
||||
@ -31,10 +32,10 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
|
||||
def handshake(self):
|
||||
req = http.read_request(self.rfile)
|
||||
key = websockets.check_client_handshake(req.headers)
|
||||
key = self.protocol.check_client_handshake(req.headers)
|
||||
|
||||
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||
headers = websockets.server_handshake_headers(key)
|
||||
headers = self.protocol.server_handshake_headers(key)
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
self.wfile.flush()
|
||||
self.handshake_done = True
|
||||
@ -48,6 +49,7 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
|
||||
def __init__(self, address, source_address=None):
|
||||
super(WebSocketsClient, self).__init__(address, source_address)
|
||||
self.protocol = websockets.WebsocketsProtocol()
|
||||
self.client_nonce = None
|
||||
|
||||
def connect(self):
|
||||
@ -55,15 +57,15 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
|
||||
preamble = http.request_preamble("GET", "/")
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = websockets.client_handshake_headers()
|
||||
headers = self.protocol.client_handshake_headers()
|
||||
self.client_nonce = headers.get_first("sec-websocket-key")
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
self.wfile.flush()
|
||||
|
||||
resp = http.read_response(self.rfile, "get", None)
|
||||
server_nonce = websockets.check_server_handshake(resp.headers)
|
||||
server_nonce = self.protocol.check_server_handshake(resp.headers)
|
||||
|
||||
if not server_nonce == websockets.create_server_nonce(
|
||||
if not server_nonce == self.protocol.create_server_nonce(
|
||||
self.client_nonce):
|
||||
self.close()
|
||||
|
||||
@ -78,6 +80,9 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
class TestWebSockets(tservers.ServerTestBase):
|
||||
handler = WebSocketsEchoHandler
|
||||
|
||||
def __init__(self):
|
||||
self.protocol = websockets.WebsocketsProtocol()
|
||||
|
||||
def random_bytes(self, n=100):
|
||||
return os.urandom(n)
|
||||
|
||||
@ -130,26 +135,26 @@ class TestWebSockets(tservers.ServerTestBase):
|
||||
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
|
||||
|
||||
def test_check_server_handshake(self):
|
||||
headers = websockets.server_handshake_headers("key")
|
||||
assert websockets.check_server_handshake(headers)
|
||||
headers = self.protocol.server_handshake_headers("key")
|
||||
assert self.protocol.check_server_handshake(headers)
|
||||
headers["Upgrade"] = ["not_websocket"]
|
||||
assert not websockets.check_server_handshake(headers)
|
||||
assert not self.protocol.check_server_handshake(headers)
|
||||
|
||||
def test_check_client_handshake(self):
|
||||
headers = websockets.client_handshake_headers("key")
|
||||
assert websockets.check_client_handshake(headers) == "key"
|
||||
headers = self.protocol.client_handshake_headers("key")
|
||||
assert self.protocol.check_client_handshake(headers) == "key"
|
||||
headers["Upgrade"] = ["not_websocket"]
|
||||
assert not websockets.check_client_handshake(headers)
|
||||
assert not self.protocol.check_client_handshake(headers)
|
||||
|
||||
|
||||
class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||
|
||||
def handshake(self):
|
||||
client_hs = http.read_request(self.rfile)
|
||||
websockets.check_client_handshake(client_hs.headers)
|
||||
self.protocol.check_client_handshake(client_hs.headers)
|
||||
|
||||
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||
headers = websockets.server_handshake_headers("malformed key")
|
||||
headers = self.protocol.server_handshake_headers("malformed key")
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
self.wfile.flush()
|
||||
self.handshake_done = True
|
||||
|
Loading…
Reference in New Issue
Block a user