websockets: handshake checks only take headers

This commit is contained in:
Aldo Cortesi 2015-04-23 08:23:51 +12:00
parent 4fb49c8e55
commit 42a87a1d8b
4 changed files with 24 additions and 40 deletions

View File

@ -33,7 +33,7 @@ def _is_valid_host(host):
return True return True
def get_line(fp): def get_request_line(fp):
""" """
Get a line, possibly preceded by a blank. Get a line, possibly preceded by a blank.
""" """
@ -41,8 +41,6 @@ def get_line(fp):
if line == "\r\n" or line == "\n": if line == "\r\n" or line == "\n":
# Possible leftover from previous message # Possible leftover from previous message
line = fp.readline() line = fp.readline()
if line == "":
raise tcp.NetLibDisconnect()
return line return line
@ -457,7 +455,9 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None):
httpversion, host, port, scheme, method, path, headers, content = ( httpversion, host, port, scheme, method, path, headers, content = (
None, None, None, None, None, None, None, None) None, None, None, None, None, None, None, None)
request_line = get_line(rfile) request_line = get_request_line(rfile)
if not request_line:
raise tcp.NetLibDisconnect()
request_line_parts = parse_init(request_line) request_line_parts = parse_init(request_line)
if not request_line_parts: if not request_line_parts:

View File

@ -350,16 +350,16 @@ def get_payload_length_pair(payload_bytestring):
return (length_code, actual_length) return (length_code, actual_length)
def check_client_handshake(req): def check_client_handshake(headers):
if req.headers.get_first("upgrade", None) != "websocket": if headers.get_first("upgrade", None) != "websocket":
return return
return req.headers.get_first('sec-websocket-key') return headers.get_first('sec-websocket-key')
def check_server_handshake(resp): def check_server_handshake(headers):
if resp.headers.get_first("upgrade", None) != "websocket": if headers.get_first("upgrade", None) != "websocket":
return return
return resp.headers.get_first('sec-websocket-accept') return headers.get_first('sec-websocket-accept')
def create_server_nonce(client_nonce): def create_server_nonce(client_nonce):

View File

@ -412,10 +412,10 @@ def test_parse_http_basic_auth():
assert not http.parse_http_basic_auth(v) assert not http.parse_http_basic_auth(v)
def test_get_line(): def test_get_request_line():
r = cStringIO.StringIO("\nfoo") r = cStringIO.StringIO("\nfoo")
assert http.get_line(r) == "foo" assert http.get_request_line(r) == "foo"
tutils.raises(tcp.NetLibDisconnect, http.get_line, r) assert not http.get_request_line(r)
class TestReadRequest(): class TestReadRequest():

View File

@ -27,7 +27,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
def handshake(self): def handshake(self):
req = http.read_request(self.rfile) req = http.read_request(self.rfile)
key = websockets.check_client_handshake(req) key = websockets.check_client_handshake(req.headers)
self.wfile.write(http.response_preamble(101) + "\r\n") self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers(key) headers = websockets.server_handshake_headers(key)
@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient):
self.wfile.flush() self.wfile.flush()
resp = http.read_response(self.rfile, "get", None) resp = http.read_response(self.rfile, "get", None)
server_nonce = websockets.check_server_handshake(resp) server_nonce = websockets.check_server_handshake(resp.headers)
if not server_nonce == websockets.create_server_nonce(self.client_nonce): if not server_nonce == websockets.create_server_nonce(self.client_nonce):
self.close() self.close()
@ -153,38 +153,22 @@ class TestWebSockets(test.ServerTestBase):
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
def test_check_server_handshake(self): def test_check_server_handshake(self):
resp = http.Response( headers = websockets.server_handshake_headers("key")
(1, 1), assert websockets.check_server_handshake(headers)
101, headers["Upgrade"] = ["not_websocket"]
"Switching Protocols", assert not websockets.check_server_handshake(headers)
websockets.server_handshake_headers("key"),
""
)
assert websockets.check_server_handshake(resp)
resp.headers["Upgrade"] = ["not_websocket"]
assert not websockets.check_server_handshake(resp)
def test_check_client_handshake(self): def test_check_client_handshake(self):
resp = http.Request( headers = websockets.client_handshake_headers("key")
"relative", assert websockets.check_client_handshake(headers) == "key"
"get", headers["Upgrade"] = ["not_websocket"]
"http", assert not websockets.check_client_handshake(headers)
"host",
22,
"/",
(1, 1),
websockets.client_handshake_headers("key"),
""
)
assert websockets.check_client_handshake(resp) == "key"
resp.headers["Upgrade"] = ["not_websocket"]
assert not websockets.check_client_handshake(resp)
class BadHandshakeHandler(WebSocketsEchoHandler): class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self): def handshake(self):
client_hs = http.read_request(self.rfile) client_hs = http.read_request(self.rfile)
websockets.check_client_handshake(client_hs) websockets.check_client_handshake(client_hs.headers)
self.wfile.write(http.response_preamble(101) + "\r\n") self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers("malformed key") headers = websockets.server_handshake_headers("malformed key")