small cleanups, working on tests

This commit is contained in:
Chandler Abraham 2015-04-11 11:35:15 -07:00
parent e41e5cbfdd
commit 0edc04814e
3 changed files with 41 additions and 28 deletions

View File

@ -26,8 +26,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
self.on_message(decoded) self.on_message(decoded)
def send_message(self, message): def send_message(self, message):
frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) frame = ws.WebSocketsFrame.default(message, from_client = False)
self.wfile.write(frame.to_bytes()) self.wfile.write(frame.safe_to_bytes())
self.wfile.flush() self.wfile.flush()
def handshake(self): def handshake(self):
@ -47,7 +47,7 @@ class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None): def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address) super(WebSocketsClient, self).__init__(address, source_address)
self.version = "13" self.version = "13"
self.key = b64encode(os.urandom(16)).decode('utf-8') self.key = ws.generate_client_nounce()
self.resource = "/" self.resource = "/"
def connect(self): def connect(self):
@ -76,6 +76,6 @@ class WebSocketsClient(tcp.TCPClient):
self.close() self.close()
def send_message(self, message): def send_message(self, message):
frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) frame = ws.WebSocketsFrame.default(message, from_client = True)
self.wfile.write(frame.to_bytes()) self.wfile.write(frame.safe_to_bytes())
self.wfile.flush() self.wfile.flush()

View File

@ -65,7 +65,6 @@ class WebSocketsFrame(object):
payload = None, # bytestring payload = None, # bytestring
masking_key = None, # 32 bit byte string masking_key = None, # 32 bit byte string
actual_payload_length = None, # any decimal integer actual_payload_length = None, # any decimal integer
use_validation = True # indicates whether or not you care if this frame adheres to the spec
): ):
self.fin = fin self.fin = fin
self.rsv1 = rsv1 self.rsv1 = rsv1
@ -78,21 +77,18 @@ class WebSocketsFrame(object):
self.payload = payload self.payload = payload
self.decoded_payload = decoded_payload self.decoded_payload = decoded_payload
self.actual_payload_length = actual_payload_length self.actual_payload_length = actual_payload_length
self.use_validation = use_validation
if self.use_validation:
self.validate_frame()
@classmethod @classmethod
def from_bytes(cls, bytestring): def from_bytes(cls, bytestring):
""" """
Construct a websocket frame from an in-memory bytestring Construct a websocket frame from an in-memory bytestring
to construct a frame from a stream of bytes, use read_frame() directly to construct a frame from a stream of bytes, use from_byte_stream() directly
""" """
self.from_byte_stream(io.BytesIO(bytestring).read) self.from_byte_stream(io.BytesIO(bytestring).read)
@classmethod @classmethod
def default_frame_from_message(cls, message, from_client = False): def default(cls, message, from_client = False):
""" """
Construct a basic websocket frame from some default values. Construct a basic websocket frame from some default values.
Creates a non-fragmented text frame. Creates a non-fragmented text frame.
@ -119,7 +115,7 @@ class WebSocketsFrame(object):
actual_payload_length = actual_length actual_payload_length = actual_length
) )
def validate_frame(self): def frame_is_valid(self):
""" """
Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame
has not been corrupted. has not been corrupted.
@ -141,10 +137,11 @@ class WebSocketsFrame(object):
assert self.actual_payload_length == len(self.payload) assert self.actual_payload_length == len(self.payload)
if self.payload is not None and self.masking_key is not None: if self.payload is not None and self.masking_key is not None:
apply_mask(self.payload, self.masking_key) == self.decoded_payload assert apply_mask(self.payload, self.masking_key) == self.decoded_payload
return True
except AssertionError: except AssertionError:
raise WebSocketFrameValidationException() return False
def human_readable(self): def human_readable(self):
return "\n".join([ return "\n".join([
@ -161,15 +158,19 @@ class WebSocketsFrame(object):
("actual_payload_length - " + str(self.actual_payload_length)), ("actual_payload_length - " + str(self.actual_payload_length)),
("use_validation - " + str(self.use_validation))]) ("use_validation - " + str(self.use_validation))])
def safe_to_bytes(self):
try:
assert self.frame_is_valid()
return self.to_bytes()
except:
raise WebSocketFrameValidationException()
def to_bytes(self): def to_bytes(self):
""" """
Serialize the frame back into the wire format, returns a bytestring Serialize the frame back into the wire format, returns a bytestring
If you haven't checked is_valid_frame() then there's no guarentees that the
serialized bytes will be correct. see safe_to_bytes()
""" """
# validate enforces all the assumptions made by this serializer
# in the spritit of mitmproxy, it's possible to create and serialize invalid frames
# by skipping validation.
if self.use_validation:
self.validate_frame()
max_16_bit_int = (1 << 16) max_16_bit_int = (1 << 16)
max_64_bit_int = (1 << 63) max_64_bit_int = (1 << 63)
@ -198,6 +199,7 @@ class WebSocketsFrame(object):
pass pass
elif self.actual_payload_length < max_16_bit_int: elif self.actual_payload_length < max_16_bit_int:
# '!H' pack as 16 bit unsigned short # '!H' pack as 16 bit unsigned short
bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length
@ -284,9 +286,6 @@ def apply_mask(message, masking_key):
def random_masking_key(): def random_masking_key():
return os.urandom(4) return os.urandom(4)
def masking_key_list(masking_key):
return [utils.bytes_to_int(byte) for byte in masking_key]
def create_client_handshake(host, port, key, version, resource): def create_client_handshake(host, port, key, version, resource):
""" """
WebSockets connections are intiated by the client with a valid HTTP upgrade request WebSockets connections are intiated by the client with a valid HTTP upgrade request

View File

@ -1,15 +1,29 @@
from netlib import test from netlib import test
from netlib.websockets import implementations as ws from netlib.websockets import implementations as impl
from netlib.websockets import websockets as ws
import os
class TestWebSockets(test.ServerTestBase): class TestWebSockets(test.ServerTestBase):
handler = ws.WebSocketsEchoHandler handler = impl.WebSocketsEchoHandler
def test_websockets_echo(self): def echo(self, msg):
msg = "hello I'm the client" client = impl.WebSocketsClient(("127.0.0.1", self.port))
client = ws.WebSocketsClient(("127.0.0.1", self.port))
client.connect() client.connect()
client.send_message(msg) client.send_message(msg)
response = client.read_next_message() response = client.read_next_message()
print "Assert response: " + response + " == msg: " + msg print "Assert response: " + response + " == msg: " + msg
assert response == msg assert response == msg
def test_simple_echo(self):
self.echo("hello I'm the client")
def test_frame_sizes(self):
small_string = os.urandom(100) # length can fit in the the 7 bit payload length
medium_string = os.urandom(50000) # 50kb, sligthly larger than can fit in a 7 bit int
large_string = os.urandom(150000) # 150kb, slightly larger than can fit in a 16 bit int
self.echo(small_string)
self.echo(medium_string)
self.echo(large_string)