diff --git a/netlib/http.py b/netlib/http.py index 2c72621dd..264388636 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,20 +29,6 @@ def _is_valid_host(host): return None return True -def is_successful_upgrade(request, response): - """ - determines if a client and server successfully agreed to an HTTP protocol upgrade - - https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism - """ - http_switching_protocols_code = 101 - - if request and response: - responseUpgrade = request.headers.get("Upgrade") - requestUpgrade = response.headers.get("Upgrade") - if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: - return requestUpgrade[0] if len(requestUpgrade) > 0 else None - return None def parse_url(url): """ diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 73a846905..1ded3b857 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -65,9 +65,6 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() server_handshake = ws.read_handshake(self.rfile.read, 1) - - if not server_handshake: - self.close() server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) @@ -75,11 +72,8 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - try: - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload - except IndexError: - self.close() - + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + def send_message(self, message): frame = ws.WebSocketsFrame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index cf9a68aa9..ea3db21d5 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -158,11 +158,10 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): - try: - assert self.is_valid() - return self.to_bytes() - except: - raise WebSocketFrameValidationException() + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() def to_bytes(self): """ diff --git a/test/test_websockets.py b/test/test_websockets.py index 0c23e355f..951aa41ff 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -22,8 +22,8 @@ class TestWebSockets(test.ServerTestBase): self.echo("hello I'm the client") def test_frame_sizes(self): - small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length - medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int + small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length + medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int self.echo(small_msg) @@ -42,6 +42,10 @@ class TestWebSockets(test.ServerTestBase): assert server_frame.is_valid() def test_serialization_bijection(self): + """ + Ensure that various frame types can be serialized/deserialized back and forth + between to_bytes() and from_bytes() + """ for is_client in [True, False]: for num_bytes in [100, 50000, 150000]: frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) @@ -50,6 +54,12 @@ class TestWebSockets(test.ServerTestBase): bytes = b'\x81\x11cba' assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + @raises(ws.WebSocketFrameValidationException) + def test_safe_to_bytes(self): + frame = ws.WebSocketsFrame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 #corrupt the frame + frame.safe_to_bytes() + class BadHandshakeHandler(impl.WebSocketsEchoHandler): def handshake(self):