From 42a87a1d8b3eeccfdd8e5e504f1cd4d90ae1dbfb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 23 Apr 2015 08:23:51 +1200 Subject: [PATCH] websockets: handshake checks only take headers --- netlib/http.py | 8 ++++---- netlib/websockets.py | 12 ++++++------ test/test_http.py | 6 +++--- test/test_websockets.py | 38 +++++++++++--------------------------- 4 files changed, 24 insertions(+), 40 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index fe27240a8..43155486c 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -33,7 +33,7 @@ def _is_valid_host(host): return True -def get_line(fp): +def get_request_line(fp): """ Get a line, possibly preceded by a blank. """ @@ -41,8 +41,6 @@ def get_line(fp): if line == "\r\n" or line == "\n": # Possible leftover from previous message line = fp.readline() - if line == "": - raise tcp.NetLibDisconnect() return line @@ -457,7 +455,9 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): httpversion, host, port, scheme, method, path, headers, content = ( None, None, None, None, None, None, None, None) - request_line = get_line(rfile) + request_line = get_request_line(rfile) + if not request_line: + raise tcp.NetLibDisconnect() request_line_parts = parse_init(request_line) if not request_line_parts: diff --git a/netlib/websockets.py b/netlib/websockets.py index d5c5c2fe4..da03768d2 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -350,16 +350,16 @@ def get_payload_length_pair(payload_bytestring): return (length_code, actual_length) -def check_client_handshake(req): - if req.headers.get_first("upgrade", None) != "websocket": +def check_client_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": return - return req.headers.get_first('sec-websocket-key') + return headers.get_first('sec-websocket-key') -def check_server_handshake(resp): - if resp.headers.get_first("upgrade", None) != "websocket": +def check_server_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": return - return resp.headers.get_first('sec-websocket-accept') + return headers.get_first('sec-websocket-accept') def create_server_nonce(client_nonce): diff --git a/test/test_http.py b/test/test_http.py index 8b99c769e..962eb9cb0 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -412,10 +412,10 @@ def test_parse_http_basic_auth(): assert not http.parse_http_basic_auth(v) -def test_get_line(): +def test_get_request_line(): r = cStringIO.StringIO("\nfoo") - assert http.get_line(r) == "foo" - tutils.raises(tcp.NetLibDisconnect, http.get_line, r) + assert http.get_request_line(r) == "foo" + assert not http.get_request_line(r) class TestReadRequest(): diff --git a/test/test_websockets.py b/test/test_websockets.py index 9e205e701..6f3b429df 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -27,7 +27,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): req = http.read_request(self.rfile) - key = websockets.check_client_handshake(req) + key = websockets.check_client_handshake(req.headers) self.wfile.write(http.response_preamble(101) + "\r\n") headers = websockets.server_handshake_headers(key) @@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() resp = http.read_response(self.rfile, "get", None) - server_nonce = websockets.check_server_handshake(resp) + server_nonce = websockets.check_server_handshake(resp.headers) if not server_nonce == websockets.create_server_nonce(self.client_nonce): self.close() @@ -153,38 +153,22 @@ class TestWebSockets(test.ServerTestBase): assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes def test_check_server_handshake(self): - resp = http.Response( - (1, 1), - 101, - "Switching Protocols", - websockets.server_handshake_headers("key"), - "" - ) - assert websockets.check_server_handshake(resp) - resp.headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_server_handshake(resp) + headers = websockets.server_handshake_headers("key") + assert websockets.check_server_handshake(headers) + headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_server_handshake(headers) def test_check_client_handshake(self): - resp = http.Request( - "relative", - "get", - "http", - "host", - 22, - "/", - (1, 1), - websockets.client_handshake_headers("key"), - "" - ) - assert websockets.check_client_handshake(resp) == "key" - resp.headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_client_handshake(resp) + headers = websockets.client_handshake_headers("key") + assert websockets.check_client_handshake(headers) == "key" + headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_client_handshake(headers) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): client_hs = http.read_request(self.rfile) - websockets.check_client_handshake(client_hs) + websockets.check_client_handshake(client_hs.headers) self.wfile.write(http.response_preamble(101) + "\r\n") headers = websockets.server_handshake_headers("malformed key")