def dumps(value):
    """
    Serialize a python object as a tnetstring.

    Supported values are None, True/False, integers, floats, bytes,
    lists/tuples and dicts, nested to any depth.

    Returns the encoded tnetstring as bytes.
    Raises ValueError if value (or anything nested inside it) is of an
    unsupported type.
    """
    # A tnetstring starts with the length of its payload, which is only
    # known once the payload has been produced.  _rdumpq() therefore emits
    # the fragments back-to-front into a deque, so that every length prefix
    # can be written after (i.e. to the left of) the data it describes.
    q = deque()
    _rdumpq(q, 0, value)
    return b''.join(q)


def dump(value, file_handle):
    """
    Serialize a python object as a tnetstring and write it to a file.

    file_handle must provide a write() method that accepts bytes.
    The file is not flushed.
    """
    file_handle.write(dumps(value))


def _rdumpq(q, size, value):
    """
    Dump value as a tnetstring onto a deque instance, last chunks first.

    Every chunk is pushed with q.appendleft(), so the deque always holds
    the tail of the final output in the right order.

    q     -- deque collecting the bytes fragments.
    size  -- number of bytes that were already in q before this call.
    value -- the object to serialize.

    Returns the new total number of bytes held in q.
    Raises ValueError for objects that cannot be serialized.
    """
    write = q.appendleft
    # The identity checks for True/False must come before the integer
    # check because bool is a subclass of int.
    if value is None:
        write(b'0:~')
        return size + 3
    elif value is True:
        write(b'4:true!')
        return size + 7
    elif value is False:
        write(b'5:false!')
        return size + 8
    elif isinstance(value, six.integer_types):
        data = str(value).encode()
        ldata = len(data)
        span = str(ldata).encode()
        write(b'#')
        write(data)
        write(b':')
        write(span)
        return size + 2 + len(span) + ldata
    elif isinstance(value, float):
        # Use repr() for float rather than str().
        # It round-trips more accurately.
        data = repr(value).encode()
        ldata = len(data)
        span = str(ldata).encode()
        write(b'^')
        write(data)
        write(b':')
        write(span)
        return size + 2 + len(span) + ldata
    elif isinstance(value, bytes):
        lvalue = len(value)
        span = str(lvalue).encode()
        write(b',')
        write(value)
        write(b':')
        write(span)
        return size + 2 + len(span) + lvalue
    elif isinstance(value, (list, tuple)):
        write(b']')
        init_size = size = size + 1
        # Items are pushed to the left, so walk them backwards to keep
        # the original order in the output.
        for item in reversed(value):
            size = _rdumpq(q, size, item)
        span = str(size - init_size).encode()
        write(b':')
        write(span)
        return size + 1 + len(span)
    elif isinstance(value, dict):
        write(b'}')
        init_size = size = size + 1
        # The value is pushed before its key so that the key ends up in
        # front of it.  Pairs are visited in iteration order and pushed to
        # the left, so they appear in the output in reverse iteration order.
        for (k, v) in value.items():
            size = _rdumpq(q, size, v)
            size = _rdumpq(q, size, k)
        span = str(size - init_size).encode()
        write(b':')
        write(span)
        return size + 1 + len(span)
    else:
        raise ValueError("unserializable object: {} ({})".format(value, type(value)))


def _gdumps(value):
    """
    Generate the bytes fragments of value dumped as a tnetstring.

    This is the straightforward top-down counterpart of _rdumpq(): the
    payload of every container is joined into an intermediate bytes object
    first so that its length can be emitted in front of it.  The result can
    be passed directly to b''.join().

    Raises ValueError for objects that cannot be serialized.
    """
    if value is None:
        yield b'0:~'
    elif value is True:
        yield b'4:true!'
    elif value is False:
        yield b'5:false!'
    elif isinstance(value, six.integer_types):
        data = str(value).encode()
        yield str(len(data)).encode()
        yield b':'
        yield data
        yield b'#'
    elif isinstance(value, float):
        data = repr(value).encode()
        yield str(len(data)).encode()
        yield b':'
        yield data
        yield b'^'
    elif isinstance(value, bytes):
        yield str(len(value)).encode()
        yield b':'
        yield value
        yield b','
    elif isinstance(value, (list, tuple)):
        sub = []
        for item in value:
            sub.extend(_gdumps(item))
        sub = b''.join(sub)
        yield str(len(sub)).encode()
        yield b':'
        yield sub
        yield b']'
    elif isinstance(value, (dict,)):
        sub = []
        for (k, v) in value.items():
            sub.extend(_gdumps(k))
            sub.extend(_gdumps(v))
        sub = b''.join(sub)
        yield str(len(sub)).encode()
        yield b':'
        yield sub
        yield b'}'
    else:
        raise ValueError("unserializable object")


def _parse_payload(tns_type, data):
    """
    Convert the payload of a single tnetstring into a python object.

    tns_type -- the one-byte type tag as bytes (b',', b'#', b'^', b'!',
                b'~', b']' or b'}').
    data     -- the payload bytes that preceded the type tag.

    Returns bytes, int, float, bool, None, list or dict depending on the
    tag.  Raises ValueError for an unknown tag or a malformed payload.
    """
    if tns_type == b',':
        return data
    if tns_type == b'#':
        try:
            return int(data)
        except ValueError:
            raise ValueError("not a tnetstring: invalid integer literal: {}".format(data))
    if tns_type == b'^':
        try:
            return float(data)
        except ValueError:
            raise ValueError("not a tnetstring: invalid float literal: {}".format(data))
    if tns_type == b'!':
        if data == b'true':
            return True
        elif data == b'false':
            return False
        else:
            raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data))
    if tns_type == b'~':
        if data:
            raise ValueError("not a tnetstring: invalid null literal")
        return None
    if tns_type == b']':
        items = []
        while data:
            item, data = pop(data)
            items.append(item)
        return items
    if tns_type == b'}':
        mapping = {}
        # A dict payload is a flat sequence of key/value tnetstrings.
        while data:
            key, data = pop(data)
            val, data = pop(data)
            mapping[key] = val
        return mapping
    raise ValueError("unknown type tag: {}".format(tns_type))


def loads(string):
    """
    Parse a tnetstring (bytes) into a python object.

    Any data following the first complete tnetstring is ignored.
    Raises ValueError if string is not a valid tnetstring.
    """
    return pop(string)[0]


def load(file_handle):
    """
    Read one tnetstring from a file and parse it into a python object.

    file_handle must provide a read() method returning bytes (i.e. be
    opened in binary mode).  Exactly one tnetstring is consumed; nothing
    past its type tag is read.

    Raises ValueError if the data read is not a valid tnetstring.
    """
    # Read the length prefix one char at a time.
    # Note that the netstring spec explicitly forbids padding zeros.
    c = file_handle.read(1)
    if not c.isdigit():
        raise ValueError("not a tnetstring: missing or invalid length prefix")
    datalen = ord(c) - ord('0')
    c = file_handle.read(1)
    if datalen != 0:
        while c.isdigit():
            datalen = (10 * datalen) + (ord(c) - ord('0'))
            if datalen > 999999999:
                errmsg = "not a tnetstring: absurdly large length prefix"
                raise ValueError(errmsg)
            c = file_handle.read(1)
    if c != b':':
        raise ValueError("not a tnetstring: missing or invalid length prefix")
    # Now we can read the payload and its type tag and hand them to the
    # same dispatch that pop() uses.
    data = file_handle.read(datalen)
    if len(data) != datalen:
        raise ValueError("not a tnetstring: length prefix too big")
    tns_type = file_handle.read(1)
    return _parse_payload(tns_type, data)


def pop(string):
    """
    Parse the first tnetstring found at the start of string.

    string -- bytes beginning with a tnetstring.

    Returns a tuple (object, remain) giving the parsed object and the
    bytes that follow the parsed tnetstring.
    Raises ValueError if string does not start with a valid tnetstring.
    """
    # Parse out data length, type and remaining string.
    try:
        prefix, rest = string.split(b':', 1)
        dlen = int(prefix)
    except ValueError:
        raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string))
    if dlen < 0:
        # int() happily parses a sign, but a length can never be negative.
        raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string))
    if len(rest) <= dlen:
        # Slicing never raises, so a payload that is shorter than announced
        # (or that lacks its type tag) has to be detected explicitly.
        raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen))
    # The tag is taken as a slice so that it is a bytes object of length
    # one on every python version.
    data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:]
    # Parse the data based on the type tag.
    return _parse_payload(tns_type, data), remain