websockets: refactor to use http and header functions in http.py

This commit is contained in:
Aldo Cortesi 2015-04-21 22:39:45 +12:00
parent e5f1264838
commit 3e0a71ea34
3 changed files with 152 additions and 194 deletions

View File

@ -4,7 +4,7 @@ import string
import urlparse import urlparse
import binascii import binascii
import sys import sys
from . import odict, utils, tcp from . import odict, utils, tcp, http_status
class HttpError(Exception): class HttpError(Exception):
@ -314,62 +314,6 @@ def parse_response_line(line):
return (proto, code, msg) return (proto, code, msg)
Response = collections.namedtuple(
"Response",
[
"httpversion",
"code",
"msg",
"headers",
"content"
]
)
def read_response(rfile, request_method, body_size_limit, include_body=True):
"""
Return an (httpversion, code, msg, headers, content) tuple.
By default, both response header and body are read.
If include_body=False is specified, content may be one of the
following:
- None, if the response is technically allowed to have a response body
- "", if the response must not have a response body (e.g. it's a
response to a HEAD request)
"""
line = rfile.readline()
# Possible leftover from previous message
if line == "\r\n" or line == "\n":
line = rfile.readline()
if not line:
raise HttpErrorConnClosed(502, "Server disconnect.")
parts = parse_response_line(line)
if not parts:
raise HttpError(502, "Invalid server response: %s" % repr(line))
proto, code, msg = parts
httpversion = parse_http_protocol(proto)
if httpversion is None:
raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
headers = read_headers(rfile)
if headers is None:
raise HttpError(502, "Invalid headers.")
if include_body:
content = read_http_body(
rfile,
headers,
body_size_limit,
request_method,
code,
False
)
else:
# if include_body==False then a None content means the body should be
# read separately
content = None
return Response(httpversion, code, msg, headers, content)
def read_http_body(*args, **kwargs): def read_http_body(*args, **kwargs):
return "".join( return "".join(
content for _, content, _ in read_http_body_chunked(*args, **kwargs) content for _, content, _ in read_http_body_chunked(*args, **kwargs)
@ -579,3 +523,71 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None):
headers, headers,
content content
) )
Response = collections.namedtuple(
"Response",
[
"httpversion",
"code",
"msg",
"headers",
"content"
]
)
def read_response(rfile, request_method, body_size_limit, include_body=True):
"""
Return an (httpversion, code, msg, headers, content) tuple.
By default, both response header and body are read.
If include_body=False is specified, content may be one of the
following:
- None, if the response is technically allowed to have a response body
- "", if the response must not have a response body (e.g. it's a
response to a HEAD request)
"""
line = rfile.readline()
# Possible leftover from previous message
if line == "\r\n" or line == "\n":
line = rfile.readline()
if not line:
raise HttpErrorConnClosed(502, "Server disconnect.")
parts = parse_response_line(line)
if not parts:
raise HttpError(502, "Invalid server response: %s" % repr(line))
proto, code, msg = parts
httpversion = parse_http_protocol(proto)
if httpversion is None:
raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
headers = read_headers(rfile)
if headers is None:
raise HttpError(502, "Invalid headers.")
if include_body:
content = read_http_body(
rfile,
headers,
body_size_limit,
request_method,
code,
False
)
else:
# if include_body==False then a None content means the body should be
# read separately
content = None
return Response(httpversion, code, msg, headers, content)
def request_preamble(method, resource, http_major="1", http_minor="1"):
return '%s %s HTTP/%s.%s' % (
method, resource, http_major, http_minor
)
def response_preamble(code, message=None, http_major="1", http_minor="1"):
if message is None:
message = http_status.RESPONSES.get(code)
return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message)

View File

@ -2,13 +2,11 @@ from __future__ import absolute_import
import base64 import base64
import hashlib import hashlib
import mimetools
import StringIO
import os import os
import struct import struct
import io import io
from . import utils from . import utils, odict
# Colleciton of utility functions that implement small portions of the RFC6455 # Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers. # WebSockets Protocol Useful for building WebSocket clients and servers.
@ -23,6 +21,7 @@ from . import utils
# The magic sha that websocket servers must know to prove they understand # The magic sha that websocket servers must know to prove they understand
# RFC6455 # RFC6455
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
class CONST(object): class CONST(object):
@ -151,9 +150,9 @@ class Frame(object):
("opcode - " + str(self.opcode)), ("opcode - " + str(self.opcode)),
("mask_bit - " + str(self.mask_bit)), ("mask_bit - " + str(self.mask_bit)),
("payload_length_code - " + str(self.payload_length_code)), ("payload_length_code - " + str(self.payload_length_code)),
("masking_key - " + str(self.masking_key)), ("masking_key - " + repr(str(self.masking_key))),
("payload - " + str(self.payload)), ("payload - " + repr(str(self.payload))),
("decoded_payload - " + str(self.decoded_payload)), ("decoded_payload - " + repr(str(self.decoded_payload))),
("actual_payload_length - " + str(self.actual_payload_length)) ("actual_payload_length - " + str(self.actual_payload_length))
]) ])
@ -198,24 +197,24 @@ class Frame(object):
second_byte = (self.mask_bit << 7) | self.payload_length_code second_byte = (self.mask_bit << 7) | self.payload_length_code
bytes = chr(first_byte) + chr(second_byte) b = chr(first_byte) + chr(second_byte)
if self.actual_payload_length < 126: if self.actual_payload_length < 126:
pass pass
elif self.actual_payload_length < CONST.MAX_16_BIT_INT: elif self.actual_payload_length < CONST.MAX_16_BIT_INT:
# '!H' pack as 16 bit unsigned short # '!H' pack as 16 bit unsigned short
# add 2 byte extended payload length # add 2 byte extended payload length
bytes += struct.pack('!H', self.actual_payload_length) b += struct.pack('!H', self.actual_payload_length)
elif self.actual_payload_length < CONST.MAX_64_BIT_INT: elif self.actual_payload_length < CONST.MAX_64_BIT_INT:
# '!Q' = pack as 64 bit unsigned long long # '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length # add 8 bytes extended payload length
bytes += struct.pack('!Q', self.actual_payload_length) b += struct.pack('!Q', self.actual_payload_length)
if self.masking_key is not None: if self.masking_key is not None:
bytes += self.masking_key b += self.masking_key
bytes += self.payload # already will be encoded if neccessary b += self.payload # already will be encoded if neccessary
return bytes return b
def to_file(self, writer): def to_file(self, writer):
writer.write(self.to_bytes()) writer.write(self.to_bytes())
@ -313,58 +312,35 @@ def random_masking_key():
return os.urandom(4) return os.urandom(4)
def create_client_handshake(host, port, key, version, resource): def client_handshake_headers(key=None, version=VERSION):
""" """
WebSockets connections are intiated by the client with a valid HTTP Create the headers for a valid HTTP upgrade request. If Key is not
upgrade request specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of ODictCaseless
""" """
headers = [ if not key:
('Host', '%s:%s' % (host, port)), key = base64.b64encode(os.urandom(16)).decode('utf-8')
return odict.ODictCaseless([
('Connection', 'Upgrade'), ('Connection', 'Upgrade'),
('Upgrade', 'websocket'), ('Upgrade', 'websocket'),
('Sec-WebSocket-Key', key), ('Sec-WebSocket-Key', key),
('Sec-WebSocket-Version', version) ('Sec-WebSocket-Version', version)
] ])
request = "GET %s HTTP/1.1" % resource
return build_handshake(headers, request)
def create_server_handshake(key): def server_handshake_headers(key):
""" """
The server response is a valid HTTP 101 response. The server response is a valid HTTP 101 response.
""" """
headers = [ return odict.ODictCaseless(
[
('Connection', 'Upgrade'), ('Connection', 'Upgrade'),
('Upgrade', 'websocket'), ('Upgrade', 'websocket'),
('Sec-WebSocket-Accept', create_server_nonce(key)) ('Sec-WebSocket-Accept', create_server_nonce(key))
] ]
request = "HTTP/1.1 101 Switching Protocols" )
return build_handshake(headers, request)
def build_handshake(headers, request):
handshake = [request.encode('utf-8')]
for header, value in headers:
handshake.append(("%s: %s" % (header, value)).encode('utf-8'))
handshake.append(b'\r\n')
return b'\r\n'.join(handshake)
def read_handshake(reader, num_bytes_per_read):
"""
From provided function that reads bytes, read in a
complete HTTP request, which terminates with a CLRF
"""
response = b''
doubleCLRF = b'\r\n\r\n'
while True:
bytes = reader.read(num_bytes_per_read)
if not bytes:
break
response += bytes
if doubleCLRF in response:
break
return response
def get_payload_length_pair(payload_bytestring): def get_payload_length_pair(payload_bytestring):
@ -384,33 +360,19 @@ def get_payload_length_pair(payload_bytestring):
return (length_code, actual_length) return (length_code, actual_length)
def process_handshake_from_client(handshake): def check_client_handshake(req):
headers = headers_from_http_message(handshake) if req.headers.get_first("upgrade", None) != "websocket":
if headers.get("Upgrade", None) != "websocket":
return return
key = headers['Sec-WebSocket-Key'] return req.headers.get_first('sec-websocket-key')
return key
def process_handshake_from_server(handshake): def check_server_handshake(resp):
headers = headers_from_http_message(handshake) if resp.headers.get_first("upgrade", None) != "websocket":
if headers.get("Upgrade", None) != "websocket":
return return
key = headers['Sec-WebSocket-Accept'] return resp.headers.get_first('sec-websocket-accept')
return key
def headers_from_http_message(http_message):
return mimetools.Message(
StringIO.StringIO(http_message.split('\r\n', 1)[1])
)
def create_server_nonce(client_nonce): def create_server_nonce(client_nonce):
return base64.b64encode( return base64.b64encode(
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
) )
def create_client_nonce():
return base64.b64encode(os.urandom(16)).decode('utf-8')

View File

@ -1,6 +1,4 @@
from netlib import tcp from netlib import tcp, test, websockets, http, odict
from netlib import test
from netlib import websockets
import io import io
import os import os
from nose.tools import raises from nose.tools import raises
@ -21,18 +19,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
self.read_next_message() self.read_next_message()
def read_next_message(self): def read_next_message(self):
decoded = websockets.Frame.from_file(self.rfile).decoded_payload frame = websockets.Frame.from_file(self.rfile)
self.on_message(decoded) self.on_message(frame.decoded_payload)
def send_message(self, message): def send_message(self, message):
frame = websockets.Frame.default(message, from_client = False) frame = websockets.Frame.default(message, from_client = False)
frame.to_file(self.wfile) frame.to_file(self.wfile)
def handshake(self): def handshake(self):
client_hs = websockets.read_handshake(self.rfile, 1) req = http.read_request(self.rfile)
key = websockets.process_handshake_from_client(client_hs) key = websockets.check_client_handshake(req)
response = websockets.create_server_handshake(key)
self.wfile.write(response) self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers(key)
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush() self.wfile.flush()
self.handshake_done = True self.handshake_done = True
@ -44,28 +44,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
class WebSocketsClient(tcp.TCPClient): class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None): def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address) super(WebSocketsClient, self).__init__(address, source_address)
self.version = "13" self.client_nonce = None
self.client_nonce = websockets.create_client_nonce()
self.resource = "/"
def connect(self): def connect(self):
super(WebSocketsClient, self).connect() super(WebSocketsClient, self).connect()
handshake = websockets.create_client_handshake( preamble = http.request_preamble("GET", "/")
self.address.host, self.wfile.write(preamble + "\r\n")
self.address.port, headers = websockets.client_handshake_headers()
self.client_nonce, self.client_nonce = headers.get_first("sec-websocket-key")
self.version, self.wfile.write(headers.format() + "\r\n")
self.resource
)
self.wfile.write(handshake)
self.wfile.flush() self.wfile.flush()
server_handshake = websockets.read_handshake(self.rfile, 1) resp = http.read_response(self.rfile, "get", None)
server_nonce = websockets.process_handshake_from_server( server_nonce = websockets.check_server_handshake(resp)
server_handshake
)
if not server_nonce == websockets.create_server_nonce(self.client_nonce): if not server_nonce == websockets.create_server_nonce(self.client_nonce):
self.close() self.close()
@ -140,51 +132,43 @@ class TestWebSockets(test.ServerTestBase):
frame.actual_payload_length = 1 # corrupt the frame frame.actual_payload_length = 1 # corrupt the frame
frame.safe_to_bytes() frame.safe_to_bytes()
def test_handshake(self): def test_check_server_handshake(self):
bad_upgrade = "not_websockets" resp = http.Response(
bad_header_handshake = websockets.build_handshake([ (1, 1),
('Host', '%s:%s' % ("a", "b")), 101,
('Connection', "c"), "Switching Protocols",
('Upgrade', bad_upgrade), websockets.server_handshake_headers("key"),
('Sec-WebSocket-Key', "d"), ""
('Sec-WebSocket-Version', "e")
], "f")
# check behavior when required header values are missing
assert None is websockets.process_handshake_from_server(
bad_header_handshake
) )
assert None is websockets.process_handshake_from_client( assert websockets.check_server_handshake(resp)
bad_header_handshake resp.headers["Upgrade"] = ["not_websocket"]
assert not websockets.check_server_handshake(resp)
def test_check_client_handshake(self):
resp = http.Request(
"relative",
"get",
"http",
"host",
22,
"/",
(1, 1),
websockets.client_handshake_headers("key"),
""
) )
assert websockets.check_client_handshake(resp) == "key"
key = "test_key" resp.headers["Upgrade"] = ["not_websocket"]
assert not websockets.check_client_handshake(resp)
client_handshake = websockets.create_client_handshake(
"a", "b", key, "d", "e"
)
assert key == websockets.process_handshake_from_client(
client_handshake
)
server_handshake = websockets.create_server_handshake(key)
assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake)
handshake = websockets.create_client_handshake("a", "b", "c", "d", "e")
stream = io.BytesIO(handshake)
assert handshake == websockets.read_handshake(stream, 1)
# ensure readhandshake doesn't loop forever on empty stream
empty_stream = io.BytesIO("")
assert "" == websockets.read_handshake(empty_stream, 1)
class BadHandshakeHandler(WebSocketsEchoHandler): class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self): def handshake(self):
client_hs = websockets.read_handshake(self.rfile, 1) client_hs = http.read_request(self.rfile)
websockets.process_handshake_from_client(client_hs) websockets.check_client_handshake(client_hs)
response = websockets.create_server_handshake("malformed_key")
self.wfile.write(response) self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers("malformed key")
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush() self.wfile.flush()
self.handshake_done = True self.handshake_done = True