refactor websockets into protocol

This commit is contained in:
Thomas Kriechbaumer 2015-07-08 09:34:10 +02:00
parent 6dcfc35011
commit bd5ee21284
4 changed files with 149 additions and 124 deletions

View File

@ -0,0 +1,2 @@
from frame import *
from protocol import *

View File

@ -5,26 +5,14 @@ import os
import struct
import io
from . import utils, odict, tcp
from .protocol import Masker
from .. import utils, odict, tcp
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
#
# Emphassis is on readabilty, simplicity and modularity, not performance or
# completeness
#
# This is a work in progress and does not yet contain all the utilites need to
# create fully complient client/servers #
# Spec: https://tools.ietf.org/html/rfc6455
DEFAULT = object()
# The magic sha that websocket servers must know to prove they understand
# RFC6455
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
MAX_16_BIT_INT = (1 << 16)
MAX_64_BIT_INT = (1 << 64)
OPCODE = utils.BiDi(
CONTINUE=0x00,
TEXT=0x01,
@ -34,101 +22,6 @@ OPCODE = utils.BiDi(
PONG=0x0a
)
class Masker(object):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
def __init__(self, key):
self.key = key
self.masks = [utils.bytes_to_int(byte) for byte in key]
self.offset = 0
def mask(self, offset, data):
result = ""
for c in data:
result += chr(ord(c) ^ self.masks[offset % 4])
offset += 1
return result
def __call__(self, data):
ret = self.mask(self.offset, data)
self.offset += len(ret)
return ret
def client_handshake_headers(key=None, version=VERSION):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of ODictCaseless
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('utf-8')
return odict.ODictCaseless([
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
('Sec-WebSocket-Key', key),
('Sec-WebSocket-Version', version)
])
def server_handshake_headers(key):
"""
The server response is a valid HTTP 101 response.
"""
return odict.ODictCaseless(
[
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
('Sec-WebSocket-Accept', create_server_nonce(key))
]
)
def make_length_code(length):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
if length <= 125:
return length
elif length >= 126 and length <= 65535:
return 126
else:
return 127
def check_client_handshake(headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first('sec-websocket-key')
def check_server_handshake(headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first('sec-websocket-accept')
def create_server_nonce(client_nonce):
return base64.b64encode(
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
)
DEFAULT = object()
class FrameHeader(object):
def __init__(
@ -153,7 +46,7 @@ class FrameHeader(object):
self.rsv3 = rsv3
if length_code is DEFAULT:
self.length_code = make_length_code(self.payload_length)
self.length_code = self._make_length_code(self.payload_length)
else:
self.length_code = length_code
@ -173,6 +66,20 @@ class FrameHeader(object):
if self.masking_key and len(self.masking_key) != 4:
raise ValueError("Masking key must be 4 bytes.")
@classmethod
def _make_length_code(self, length):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
if length <= 125:
return length
elif length >= 126 and length <= 65535:
return 126
else:
return 127
def human_readable(self):
vals = [
"ws frame:",

View File

@ -0,0 +1,111 @@
from __future__ import absolute_import
import base64
import hashlib
import os
import struct
import io
from .. import utils, odict, tcp
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
#
# Emphassis is on readabilty, simplicity and modularity, not performance or
# completeness
#
# This is a work in progress and does not yet contain all the utilites need to
# create fully complient client/servers #
# Spec: https://tools.ietf.org/html/rfc6455
# The magic sha that websocket servers must know to prove they understand
# RFC6455
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
HEADER_WEBSOCKET_KEY = 'sec-websocket-key'
HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept'
HEADER_WEBSOCKET_VERSION = 'sec-websocket-version'
class Masker(object):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
def __init__(self, key):
self.key = key
self.masks = [utils.bytes_to_int(byte) for byte in key]
self.offset = 0
def mask(self, offset, data):
result = ""
for c in data:
result += chr(ord(c) ^ self.masks[offset % 4])
offset += 1
return result
def __call__(self, data):
ret = self.mask(self.offset, data)
self.offset += len(ret)
return ret
class WebsocketsProtocol(object):
def __init__(self):
pass
@classmethod
def client_handshake_headers(self, key=None, version=VERSION):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of ODictCaseless
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('utf-8')
return odict.ODictCaseless([
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
(HEADER_WEBSOCKET_KEY, key),
(HEADER_WEBSOCKET_VERSION, version)
])
@classmethod
def server_handshake_headers(self, key):
"""
The server response is a valid HTTP 101 response.
"""
return odict.ODictCaseless(
[
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
(HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key))
]
)
@classmethod
def check_client_handshake(self, headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first(HEADER_WEBSOCKET_KEY)
@classmethod
def check_server_handshake(self, headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first(HEADER_WEBSOCKET_ACCEPT)
@classmethod
def create_server_nonce(self, client_nonce):
return base64.b64encode(
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
)

View File

@ -12,6 +12,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
super(WebSocketsEchoHandler, self).__init__(
connection, address, server
)
self.protocol = websockets.WebsocketsProtocol()
self.handshake_done = False
def handle(self):
@ -31,10 +32,10 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
def handshake(self):
req = http.read_request(self.rfile)
key = websockets.check_client_handshake(req.headers)
key = self.protocol.check_client_handshake(req.headers)
self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers(key)
headers = self.protocol.server_handshake_headers(key)
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
self.handshake_done = True
@ -48,6 +49,7 @@ class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.protocol = websockets.WebsocketsProtocol()
self.client_nonce = None
def connect(self):
@ -55,15 +57,15 @@ class WebSocketsClient(tcp.TCPClient):
preamble = http.request_preamble("GET", "/")
self.wfile.write(preamble + "\r\n")
headers = websockets.client_handshake_headers()
headers = self.protocol.client_handshake_headers()
self.client_nonce = headers.get_first("sec-websocket-key")
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
resp = http.read_response(self.rfile, "get", None)
server_nonce = websockets.check_server_handshake(resp.headers)
server_nonce = self.protocol.check_server_handshake(resp.headers)
if not server_nonce == websockets.create_server_nonce(
if not server_nonce == self.protocol.create_server_nonce(
self.client_nonce):
self.close()
@ -78,6 +80,9 @@ class WebSocketsClient(tcp.TCPClient):
class TestWebSockets(tservers.ServerTestBase):
handler = WebSocketsEchoHandler
def __init__(self):
self.protocol = websockets.WebsocketsProtocol()
def random_bytes(self, n=100):
return os.urandom(n)
@ -130,26 +135,26 @@ class TestWebSockets(tservers.ServerTestBase):
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
def test_check_server_handshake(self):
headers = websockets.server_handshake_headers("key")
assert websockets.check_server_handshake(headers)
headers = self.protocol.server_handshake_headers("key")
assert self.protocol.check_server_handshake(headers)
headers["Upgrade"] = ["not_websocket"]
assert not websockets.check_server_handshake(headers)
assert not self.protocol.check_server_handshake(headers)
def test_check_client_handshake(self):
headers = websockets.client_handshake_headers("key")
assert websockets.check_client_handshake(headers) == "key"
headers = self.protocol.client_handshake_headers("key")
assert self.protocol.check_client_handshake(headers) == "key"
headers["Upgrade"] = ["not_websocket"]
assert not websockets.check_client_handshake(headers)
assert not self.protocol.check_client_handshake(headers)
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
client_hs = http.read_request(self.rfile)
websockets.check_client_handshake(client_hs.headers)
self.protocol.check_client_handshake(client_hs.headers)
self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers("malformed key")
headers = self.protocol.server_handshake_headers("malformed key")
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
self.handshake_done = True