From 5692c7359c7c468e1ee35f9bdd3b7b92fd67592c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 16 Feb 2011 11:21:06 +1300 Subject: [PATCH] Import Will McGugan's netstring module. Module is in the Public Domain. I expect to modify and extend this module, so I've imported into main library rather than contrib. Code has been reformatted to suite our code standard, tests have been extrated into /tests directory. --- libmproxy/netstring.py | 230 +++++++++++++++++++++++++++++++++++++++++ test/test_netstring.py | 63 +++++++++++ 2 files changed, 293 insertions(+) create mode 100644 libmproxy/netstring.py create mode 100644 test/test_netstring.py diff --git a/libmproxy/netstring.py b/libmproxy/netstring.py new file mode 100644 index 000000000..a9b71a329 --- /dev/null +++ b/libmproxy/netstring.py @@ -0,0 +1,230 @@ +""" + Netstring is a module for encoding and decoding netstring streams. + See http://cr.yp.to/proto/netstrings.txt for more information on netstrings. + Author: Will McGugan (http://www.willmcgugan.com) +""" +from cStringIO import StringIO + + +def header(data): + return str(len(data))+":" + + +def encode(data): + if not isinstance(data, str): + raise ValueError("data should be of type 'str'") + + return "%i:%s," % (len(data), data) + + +class FileEncoder(object): + def __init__(self, file_out): + """" + file_out -- A writable file object + """ + self.file_out = file_out + + def write(self, data): + """ + Encodes a netstring and writes it to the file object. + + data -- A string to be encoded and written + """ + if not isinstance(data, str): + raise ValueError("data should be of type 'str'") + write = self.file_out.write + write(header(data)) + write(data) + write(',') + return self + + +def netstrings_to_file(file_out, data_container): + """ + Writes a container of netstrings to a file. + + file_out -- A writeable file-like object + data_container -- An iterable of strings + """ + write = file_out.write + for s in data_container: + if not isinstance(s, str): + raise ValueError("data should be of type 'str'") + write(header(s)) + write(s) + write(',') + + +def encode_netstrings(data_container): + """ + Encodes a number of strings as sequence of netstrings. + + data_container -- An iterable of strings to be encoded + """ + return "".join(encode(s) for s in data_container) + + +class DecoderError(Exception): + ( + PRECEDING_ZERO_IN_SIZE, + MAX_SIZE_REACHED, + ILLEGAL_DIGIT_IN_SIZE, + ILLEGAL_DIGIT + ) = range(4) + error_text = { + PRECEDING_ZERO_IN_SIZE:"PRECEDING_ZERO_IN_SIZE", + MAX_SIZE_REACHED:"MAX_SIZE_REACHED", + ILLEGAL_DIGIT_IN_SIZE:"ILLEGAL_DIGIT_IN_SIZE", + ILLEGAL_DIGIT:"ILLEGAL_DIGIT" + } + def __init__(self, code, text): + Exception.__init__(self) + self.code = code + self.text = text + + def __str__(self): + return "%s (#%i), %s" % (DecoderError.error_text[self.code], self.code, self.text) + + +class Decoder(object): + """ + A netstring decoder. + Turns a netstring stream in to a number of discreet strings. + """ + def __init__(self, max_size=None): + """ + Create a netstring-stream decoder object. + + max_size -- The maximum size of a netstring encoded string, after which + a DecoderError will be throw. A value of None (the default) indicates + that there should be no maximum string size. + """ + self.max_size = max_size + self.data_pos = 0 + self.string_start = 0 + self.expecting_terminator = False + self.size_string = "" + self.data_size = None + self.remaining_bytes = 0 + self.data_out = StringIO() + self.yield_data = "" + + def __str__(self): + if self.data_size is None: + bytes = len(self.size_string) + else: + bytes = self.data_out.tell() + return ""%bytes + + def peek_buffer(self): + """ + Returns any bytes not used by decoder. + """ + return self.data_out.getvalue() + + def reset(self): + """ + Resets decoder to initial state, and discards any cached stream data. + """ + self.data_pos = 0 + self.string_start = 0 + self.expecting_terminator = False + self.size_string = "" + self.data_size = None + self.remaining_bytes = 0 + self.yield_data = "" + + self.data_out.reset() + self.data_out.truncate() + + def feed(self, data): + """ + A generator that yields 0 or more strings from the given data. + + data -- A string containing complete or partial netstring data + """ + if not isinstance(data, str): + raise ValueError("data should be of type 'str'") + self.data_pos = 0 + self.string_start = 0 + while self.data_pos < len(data): + if self.expecting_terminator: + c = data[self.data_pos] + self.data_pos += 1 + if c != ',': + raise DecoderError(DecoderError.ILLEGAL_DIGIT, "Illegal digit (%s) at end of data"%repr(c)) + yield self.yield_data + self.yield_data = "" + self.expecting_terminator = False + elif self.data_size is None: + c = data[self.data_pos] + self.data_pos += 1 + + if not len(self.size_string): + self.string_start = self.data_pos-1 + + if c in "0123456789": + if self.size_string == '0': + raise DecoderError(DecoderError.PRECEDING_ZERO_IN_SIZE, "Preceding zeros in size field illegal") + self.size_string += c + if self.max_size is not None and int(self.size_string) > self.max_size: + raise DecoderError(DecoderError.MAX_SIZE_REACHED, "Maximum size of netstring exceeded") + + elif c == ":": + if not len(self.size_string): + raise DecoderError(DecoderError.ILLEGAL_DIGIT_IN_SIZE, "Illegal digit (%s) in size field"%repr(c)) + self.data_size = int(self.size_string) + self.remaining_bytes = self.data_size + + else: + raise DecoderError(DecoderError.ILLEGAL_DIGIT_IN_SIZE, "Illegal digit (%s) in size field"%repr(c)) + + elif self.data_size is not None: + get_bytes = min(self.remaining_bytes, len(data)-self.data_pos) + chunk = data[self.data_pos:self.data_pos+get_bytes] + whole_string = len(chunk) == self.data_size + if not whole_string: + self.data_out.write(chunk) + self.data_pos += get_bytes + self.remaining_bytes -= get_bytes + if self.remaining_bytes == 0: + if whole_string: + self.yield_data = chunk + else: + self.yield_data = self.data_out.getvalue() + self.data_out.reset() + self.data_out.truncate() + self.data_size = None + self.size_string = "" + self.remaining_bytes = 0 + self.expecting_terminator = True + + +def decode(data): + """ + Decodes netstrings and returns a tuple containing a + list of strings, and any remaining data. + + data -- A string containing netstring data + """ + decoder = Decoder() + netstrings = list(decoder.feed(data)) + remaining = data[decoder.string_start:] + return netstrings, remaining + + +def decode_file(file_in, buffer_size=1024): + """ + Generates 0 or more strings from a netstring file. + + file_in -- A readable file-like object containing netstring data + buffer_size -- The number of bytes to attempt to read in each iteration + (default = 1024). + """ + decoder = Decoder() + while True: + data = file_in.read(buffer_size) + if not len(data): + return + for s in decoder.feed(data): + yield s diff --git a/test/test_netstring.py b/test/test_netstring.py new file mode 100644 index 000000000..efcfd83f1 --- /dev/null +++ b/test/test_netstring.py @@ -0,0 +1,63 @@ +from libmproxy import netstring +from cStringIO import StringIO +import libpry + + + +class uNetstring(libpry.AutoTree): + def setUp(self): + self.test_data = "Netstring module by Will McGugan" + self.encoded_data = "9:Netstring,6:module,2:by,4:Will,7:McGugan," + + def test_header(self): + tests = [ ("netstring", "9:"), + ("Will McGugan", "12:"), + ("", "0:") ] + for test, result in tests: + assert netstring.header(test) == result + + def test_encode(self): + tests = [ ("netstring", "9:netstring,"), + ("Will McGugan", "12:Will McGugan,"), + ("", "0:,") ] + for test, result in tests: + assert netstring.encode(test) == result + + def test_file_encoder(self): + file_out = StringIO() + data = self.test_data.split() + encoder = netstring.FileEncoder(file_out) + for s in data: + encoder.write(s) + encoded_data = file_out.getvalue() + assert encoded_data == self.encoded_data + + def test_decode_file(self): + data = self.test_data.split() + for buffer_size in range(1, len(self.encoded_data)): + file_in = StringIO(self.encoded_data[:]) + decoded_data = list(netstring.decode_file(file_in, buffer_size = buffer_size)) + assert decoded_data == data + + def test_decoder(self): + encoded_data = self.encoded_data + for step in range(1, len(encoded_data)): + i = 0 + chunks = [] + while i < len(encoded_data): + chunks.append(encoded_data[i:i+step]) + i += step + decoder = netstring.Decoder() + decoded_data = [] + for chunk in chunks: + for s in decoder.feed(chunk): + decoded_data.append(s) + assert decoded_data == self.test_data.split() + + + + +tests = [ + uNetstring() +] +