From 1b1ccab8b7f88c9e7e6f1d5ae8d6782bc9a1ac2e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 19 Jun 2012 09:58:50 +1200 Subject: [PATCH] Extract protocol and tcp server implementations into netlib. --- .gitignore | 3 +- libmproxy/flow.py | 157 +----------------------------- libmproxy/netlib.py | 182 ---------------------------------- libmproxy/protocol.py | 220 ------------------------------------------ libmproxy/proxy.py | 29 +++--- libmproxy/utils.py | 5 +- setup.py | 2 +- test/test_netlib.py | 93 ------------------ test/test_protocol.py | 163 ------------------------------- 9 files changed, 23 insertions(+), 831 deletions(-) delete mode 100644 libmproxy/netlib.py delete mode 100644 libmproxy/protocol.py delete mode 100644 test/test_netlib.py delete mode 100644 test/test_protocol.py diff --git a/.gitignore b/.gitignore index 78b1cdb59..b88b179bd 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,5 @@ MANIFEST *.swo mitmproxyc mitmdumpc -mitmplaybackc -mitmrecordc +netlib .coverage diff --git a/libmproxy/flow.py b/libmproxy/flow.py index a737057e3..f9a9a75d7 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -21,11 +21,15 @@ import hashlib, Cookie, cookielib, copy, re, urlparse import time import tnetstring, filt, script, utils, encoding, proxy from email.utils import parsedate_tz, formatdate, mktime_tz -import controller, version, certutils, protocol +from netlib import odict, protocol +import controller, version, certutils HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" CONTENT_MISSING = 0 +ODict = odict.ODict +ODictCaseless = odict.ODictCaseless + class ReplaceHooks: def __init__(self): @@ -117,157 +121,6 @@ class ScriptContext: self._master.replay_request(f) -class ODict: - """ - A dictionary-like object for managing ordered (key, value) data. - """ - def __init__(self, lst=None): - self.lst = lst or [] - - def _kconv(self, s): - return s - - def __eq__(self, other): - return self.lst == other.lst - - def __getitem__(self, k): - """ - Returns a list of values matching key. - """ - ret = [] - k = self._kconv(k) - for i in self.lst: - if self._kconv(i[0]) == k: - ret.append(i[1]) - return ret - - def _filter_lst(self, k, lst): - k = self._kconv(k) - new = [] - for i in lst: - if self._kconv(i[0]) != k: - new.append(i) - return new - - def __len__(self): - """ - Total number of (key, value) pairs. - """ - return len(self.lst) - - def __setitem__(self, k, valuelist): - """ - Sets the values for key k. If there are existing values for this - key, they are cleared. - """ - if isinstance(valuelist, basestring): - raise ValueError("ODict valuelist should be lists.") - new = self._filter_lst(k, self.lst) - for i in valuelist: - new.append([k, i]) - self.lst = new - - def __delitem__(self, k): - """ - Delete all items matching k. - """ - self.lst = self._filter_lst(k, self.lst) - - def __contains__(self, k): - for i in self.lst: - if self._kconv(i[0]) == self._kconv(k): - return True - return False - - def add(self, key, value): - self.lst.append([key, str(value)]) - - def get(self, k, d=None): - if k in self: - return self[k] - else: - return d - - def items(self): - return self.lst[:] - - def _get_state(self): - return [tuple(i) for i in self.lst] - - @classmethod - def _from_state(klass, state): - return klass([list(i) for i in state]) - - def copy(self): - """ - Returns a copy of this object. - """ - lst = copy.deepcopy(self.lst) - return self.__class__(lst) - - def __repr__(self): - elements = [] - for itm in self.lst: - elements.append(itm[0] + ": " + itm[1]) - elements.append("") - return "\r\n".join(elements) - - def in_any(self, key, value, caseless=False): - """ - Do any of the values matching key contain value? - - If caseless is true, value comparison is case-insensitive. - """ - if caseless: - value = value.lower() - for i in self[key]: - if caseless: - i = i.lower() - if value in i: - return True - return False - - def match_re(self, expr): - """ - Match the regular expression against each (key, value) pair. For - each pair a string of the following format is matched against: - - "key: value" - """ - for k, v in self.lst: - s = "%s: %s"%(k, v) - if re.search(expr, s): - return True - return False - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both keys and - values. Encoded content will be decoded before replacement, and - re-encoded afterwards. - - Returns the number of replacements made. - """ - nlst, count = [], 0 - for i in self.lst: - k, c = utils.safe_subn(pattern, repl, i[0], *args, **kwargs) - count += c - v, c = utils.safe_subn(pattern, repl, i[1], *args, **kwargs) - count += c - nlst.append([k, v]) - self.lst = nlst - return count - - -class ODictCaseless(ODict): - """ - A variant of ODict with "caseless" keys. This version _preserves_ key - case, but does not consider case when setting or getting items. - """ - def _kconv(self, s): - return s.lower() - - class decoded(object): """ diff --git a/libmproxy/netlib.py b/libmproxy/netlib.py deleted file mode 100644 index 08ccba091..000000000 --- a/libmproxy/netlib.py +++ /dev/null @@ -1,182 +0,0 @@ -import select, socket, threading, traceback, sys -from OpenSSL import SSL - - -class NetLibError(Exception): pass - - -class FileLike: - def __init__(self, o): - self.o = o - - def __getattr__(self, attr): - return getattr(self.o, attr) - - def flush(self): - pass - - def read(self, length): - result = '' - while len(result) < length: - try: - data = self.o.read(length) - except SSL.ZeroReturnError: - break - if not data: - break - result += data - return result - - def write(self, v): - self.o.sendall(v) - - def readline(self, size = None): - result = '' - bytes_read = 0 - while True: - if size is not None and bytes_read >= size: - break - ch = self.read(1) - bytes_read += 1 - if not ch: - break - else: - result += ch - if ch == '\n': - break - return result - - -class TCPClient: - def __init__(self, ssl, host, port, clientcert): - self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert - self.connection, self.rfile, self.wfile = None, None, None - self.cert = None - self.connect() - - def connect(self): - try: - addr = socket.gethostbyname(self.host) - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if self.ssl: - context = SSL.Context(SSL.SSLv23_METHOD) - if self.clientcert: - context.use_certificate_file(self.clientcert) - server = SSL.Connection(context, server) - server.connect((addr, self.port)) - if self.ssl: - self.cert = server.get_peer_certificate() - self.rfile, self.wfile = FileLike(server), FileLike(server) - else: - self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') - except socket.error, err: - raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) - self.connection = server - - -class BaseHandler: - rbufsize = -1 - wbufsize = 0 - def __init__(self, connection, client_address, server): - self.connection = connection - self.rfile = self.connection.makefile('rb', self.rbufsize) - self.wfile = self.connection.makefile('wb', self.wbufsize) - - self.client_address = client_address - self.server = server - self.handle() - self.finish() - - def convert_to_ssl(self, cert, key): - ctx = SSL.Context(SSL.SSLv23_METHOD) - ctx.use_privatekey_file(key) - ctx.use_certificate_file(cert) - self.connection = SSL.Connection(ctx, self.connection) - self.connection.set_accept_state() - self.rfile = FileLike(self.connection) - self.wfile = FileLike(self.connection) - - def finish(self): - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.connection.close() - self.wfile.close() - self.rfile.close() - except IOError: # pragma: no cover - pass - - def handle(self): # pragma: no cover - raise NotImplementedError - - -class TCPServer: - request_queue_size = 20 - def __init__(self, server_address): - self.server_address = server_address - self.__is_shut_down = threading.Event() - self.__shutdown_request = False - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - self.socket.listen(self.request_queue_size) - self.port = self.socket.getsockname()[1] - - def request_thread(self, request, client_address): - try: - self.handle_connection(request, client_address) - request.close() - except: - self.handle_error(request, client_address) - request.close() - - def serve_forever(self, poll_interval=0.5): - self.__is_shut_down.clear() - try: - while not self.__shutdown_request: - r, w, e = select.select([self.socket], [], [], poll_interval) - if self.socket in r: - try: - request, client_address = self.socket.accept() - except socket.error: - return - try: - t = threading.Thread( - target = self.request_thread, - args = (request, client_address) - ) - t.setDaemon(1) - t.start() - except: - self.handle_error(request, client_address) - request.close() - finally: - self.__shutdown_request = False - self.__is_shut_down.set() - - def shutdown(self): - self.__shutdown_request = True - self.__is_shut_down.wait() - self.handle_shutdown() - - def handle_error(self, request, client_address, fp=sys.stderr): - """ - Called when handle_connection raises an exception. - """ - print >> fp, '-'*40 - print >> fp, "Error processing of request from %s:%s"%client_address - print >> fp, traceback.format_exc() - print >> fp, '-'*40 - - def handle_connection(self, request, client_address): # pragma: no cover - """ - Called after client connection. - """ - raise NotImplementedError - - def handle_shutdown(self): - """ - Called after server shutdown. - """ - pass diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py deleted file mode 100644 index 547bff9e1..000000000 --- a/libmproxy/protocol.py +++ /dev/null @@ -1,220 +0,0 @@ -import string, urlparse - -class ProtocolError(Exception): - def __init__(self, code, msg): - self.code, self.msg = code, msg - - def __str__(self): - return "ProtocolError(%s, %s)"%(self.code, self.msg) - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - """ - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - if not scheme: - return None - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line - is reached. Return a ODictCaseless object. - """ - ret = [] - name = '' - while 1: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i+1:].strip() - ret.append([name, value]) - return ret - - -def read_chunked(fp, limit): - content = "" - total = 0 - while 1: - line = fp.readline(128) - if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - continue - try: - length = int(line,16) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise ProtocolError(400, "Invalid chunked encoding length: %s"%line) - if not length: - break - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) - raise ProtocolError(509, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': - raise IOError("Malformed chunked body") - while 1: - line = fp.readline() - if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - break - return content - - -def has_chunked_encoding(headers): - for i in headers["transfer-encoding"]: - for j in i.split(","): - if j.lower() == "chunked": - return True - return False - - -def read_http_body(rfile, headers, all, limit): - if has_chunked_encoding(headers): - content = read_chunked(rfile, limit) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif all: - content = rfile.read(limit if limit else None) - else: - content = "" - return content - - -def parse_http_protocol(s): - if not s.startswith("HTTP/"): - return None - major, minor = s.split('/')[1].split('.') - major = int(major) - minor = int(minor) - return major, minor - - -def parse_init_connect(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - if method != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - port = int(port) - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return host, port, httpversion - - -def parse_init_proxy(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - try: - method, url, protocol = string.split(line) - except ValueError: - return None - if not (url.startswith("/") or url == "*"): - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return method, url, httpversion - - -def request_connection_close(httpversion, headers): - """ - Checks the request to see if the client connection should be closed. - """ - if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - return True - elif value == "keep-alive": - return False - # HTTP 1.1 connections are assumed to be persistent - if httpversion == (1, 1): - return False - return True - - -def response_connection_close(httpversion, headers): - """ - Checks the response to see if the client connection should be closed. - """ - if request_connection_close(httpversion, headers): - return True - elif not has_chunked_encoding(headers) and "content-length" in headers: - return True - return False - - -def read_http_body_request(rfile, wfile, headers, httpversion, limit): - if "expect" in headers: - # FIXME: Should be forwarded upstream - expect = ",".join(headers['expect']) - if expect == "100-continue" and httpversion >= (1, 1): - wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) - wfile.write('\r\n') - del headers['expect'] - return read_http_body(rfile, headers, False, limit) - - diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 58ab7a580..04734fcbb 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -15,8 +15,9 @@ import sys, os, string, socket, time import shutil, tempfile, threading import optparse, SocketServer -import utils, flow, certutils, version, wsgi, netlib, protocol from OpenSSL import SSL +from netlib import odict, tcp, protocol +import utils, flow, certutils, version, wsgi class ProxyError(Exception): @@ -56,18 +57,18 @@ class RequestReplayThread(threading.Thread): except (ProxyError, protocol.ProtocolError), v: err = flow.Error(self.flow.request, v.msg) err._send(self.masterq) - except netlib.NetLibError, v: + except tcp.NetLibError, v: raise ProxyError(502, v) -class ServerConnection(netlib.TCPClient): +class ServerConnection(tcp.TCPClient): def __init__(self, config, scheme, host, port): clientcert = None if config.clientcerts: path = os.path.join(config.clientcerts, self.host) + ".pem" if os.path.exists(clientcert): clientcert = path - netlib.TCPClient.__init__( + tcp.TCPClient.__init__( self, True if scheme == "https" else False, host, @@ -107,7 +108,7 @@ class ServerConnection(netlib.TCPClient): code = int(code) except ValueError: raise ProxyError(502, "Invalid server response: %s."%line) - headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) + headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) if code >= 100 and code <= 199: return self.read_response() if request.method == "HEAD" or code == 204 or code == 304: @@ -125,13 +126,13 @@ class ServerConnection(netlib.TCPClient): pass -class ProxyHandler(netlib.BaseHandler): +class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, q): self.mqueue = q self.config = config self.server_conn = None self.proxy_connect_state = None - netlib.BaseHandler.__init__(self, connection, client_address, server) + tcp.BaseHandler.__init__(self, connection, client_address, server) def handle(self): cc = flow.ClientConnect(self.client_address) @@ -150,7 +151,7 @@ class ProxyHandler(netlib.BaseHandler): if not self.server_conn: try: self.server_conn = ServerConnection(self.config, scheme, host, port) - except netlib.NetLibError, v: + except tcp.NetLibError, v: raise ProxyError(502, v) def handle_request(self, cc): @@ -243,7 +244,7 @@ class ProxyHandler(netlib.BaseHandler): else: scheme = "http" method, path, httpversion = protocol.parse_init_http(line) - headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) + headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) @@ -251,7 +252,7 @@ class ProxyHandler(netlib.BaseHandler): elif self.config.reverse_proxy: scheme, host, port = self.config.reverse_proxy method, path, httpversion = protocol.parse_init_http(line) - headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) + headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) @@ -278,14 +279,14 @@ class ProxyHandler(netlib.BaseHandler): if self.proxy_connect_state: host, port, httpversion = self.proxy_connect_state method, path, httpversion = protocol.parse_init_http(line) - headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) + headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content) else: method, scheme, host, port, path, httpversion = protocol.parse_init_proxy(line) - headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) + headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) @@ -317,7 +318,7 @@ class ProxyHandler(netlib.BaseHandler): class ProxyServerError(Exception): pass -class ProxyServer(netlib.TCPServer): +class ProxyServer(tcp.TCPServer): allow_reuse_address = True bound = True def __init__(self, config, port, address=''): @@ -326,7 +327,7 @@ class ProxyServer(netlib.TCPServer): """ self.config, self.port, self.address = config, port, address try: - netlib.TCPServer.__init__(self, (address, port)) + tcp.TCPServer.__init__(self, (address, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.masterq = None diff --git a/libmproxy/utils.py b/libmproxy/utils.py index 989bb6951..35c7a8782 100644 --- a/libmproxy/utils.py +++ b/libmproxy/utils.py @@ -15,7 +15,7 @@ import os, datetime, urlparse, string, urllib, re import time, functools, cgi import json -import protocol +from netlib import protocol def timestamp(): """ @@ -294,6 +294,3 @@ def safe_subn(pattern, repl, target, *args, **kwargs): need a better solution that is aware of the actual content ecoding. """ return re.subn(str(pattern), str(repl), target, *args, **kwargs) - - - diff --git a/setup.py b/setup.py index 4070eb1bf..88a39f381 100644 --- a/setup.py +++ b/setup.py @@ -92,5 +92,5 @@ setup( "Topic :: Internet :: Proxy Servers", "Topic :: Software Development :: Testing" ], - install_requires=['urwid>=1.0', 'pyasn1>0.1.2', 'pyopenssl>=0.12', "PIL", "lxml"], + install_requires=["netlib", "urwid>=1.0", "pyasn1>0.1.2", "pyopenssl>=0.12", "PIL", "lxml"], ) diff --git a/test/test_netlib.py b/test/test_netlib.py deleted file mode 100644 index 19902d177..000000000 --- a/test/test_netlib.py +++ /dev/null @@ -1,93 +0,0 @@ -import cStringIO, threading, Queue -from libmproxy import netlib -import tutils - -class ServerThread(threading.Thread): - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase: - @classmethod - def setupAll(cls): - cls.server = ServerThread(cls.makeserver()) - cls.server.start() - - @classmethod - def teardownAll(cls): - cls.server.shutdown() - - -class THandler(netlib.BaseHandler): - def handle(self): - v = self.rfile.readline() - if v.startswith("echo"): - self.wfile.write(v) - elif v.startswith("error"): - raise ValueError("Testing an error.") - self.wfile.flush() - - -class TServer(netlib.TCPServer): - def __init__(self, addr, q): - netlib.TCPServer.__init__(self, addr) - self.q = q - - def handle_connection(self, request, client_address): - THandler(request, client_address, self) - - def handle_error(self, request, client_address): - s = cStringIO.StringIO() - netlib.TCPServer.handle_error(self, request, client_address, s) - self.q.put(s.getvalue()) - - -class TestServer(ServerTestBase): - @classmethod - def makeserver(cls): - cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), cls.q) - cls.port = s.port - return s - - def test_echo(self): - testval = "echo!\n" - c = netlib.TCPClient(False, "127.0.0.1", self.port, None) - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - def test_error(self): - testval = "error!\n" - c = netlib.TCPClient(False, "127.0.0.1", self.port, None) - c.wfile.write(testval) - c.wfile.flush() - assert "Testing an error" in self.q.get() - - -class TestTCPClient: - def test_conerr(self): - tutils.raises(netlib.NetLibError, netlib.TCPClient, False, "127.0.0.1", 0, None) - - -class TestFileLike: - def test_wrap(self): - s = cStringIO.StringIO("foobar\nfoobar") - s = netlib.FileLike(s) - s.flush() - assert s.readline() == "foobar\n" - assert s.readline() == "foobar" - # Test __getattr__ - assert s.isatty - - def test_limit(self): - s = cStringIO.StringIO("foobar\nfoobar") - s = netlib.FileLike(s) - assert s.readline(3) == "foo" diff --git a/test/test_protocol.py b/test/test_protocol.py deleted file mode 100644 index 81b5fefb6..000000000 --- a/test/test_protocol.py +++ /dev/null @@ -1,163 +0,0 @@ -import cStringIO, textwrap -from libmproxy import protocol, flow -import tutils - -def test_has_chunked_encoding(): - h = flow.ODictCaseless() - assert not protocol.has_chunked_encoding(h) - h["transfer-encoding"] = ["chunked"] - assert protocol.has_chunked_encoding(h) - - -def test_read_chunked(): - s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises(IOError, protocol.read_chunked, s, None) - - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert protocol.read_chunked(s, None) == "a" - - s = cStringIO.StringIO("\r\n") - tutils.raises(IOError, protocol.read_chunked, s, None) - - s = cStringIO.StringIO("1\r\nfoo") - tutils.raises(IOError, protocol.read_chunked, s, None) - - s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None) - - -def test_request_connection_close(): - h = flow.ODictCaseless() - assert protocol.request_connection_close((1, 0), h) - assert not protocol.request_connection_close((1, 1), h) - - h["connection"] = ["keep-alive"] - assert not protocol.request_connection_close((1, 1), h) - - -def test_read_http_body(): - h = flow.ODict() - s = cStringIO.StringIO("testing") - assert protocol.read_http_body(s, h, False, None) == "" - - h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") - tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None) - - h["content-length"] = [5] - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, False, None)) == 5 - s = cStringIO.StringIO("testing") - tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4) - - h = flow.ODict() - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, True, 4)) == 4 - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, True, 100)) == 7 - -def test_parse_http_protocol(): - assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) - assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not protocol.parse_http_protocol("foo/0.0") - - -def test_parse_init_connect(): - assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("bogus") - assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") - - -def test_prase_init_proxy(): - u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - assert not protocol.parse_init_proxy("invalid") - assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") - assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion= protocol.parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - assert not protocol.parse_init_http("invalid") - assert not protocol.parse_init_http("GET invalid HTTP/1.1") - assert not protocol.parse_init_http("GET /test foo/1.1") - - -class TestReadHeaders: - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = protocol.read_headers(s) - assert h == [["Header", "one"], ["Header2", "two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = protocol.read_headers(s) - assert h == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = protocol.read_headers(s) - assert h == [["Header", "one\r\n two"], ["Header2", "three"]] - - -def test_parse_url(): - assert not protocol.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = protocol.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = protocol.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = protocol.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = protocol.parse_url("https://foo") - assert po == 443 - - assert not protocol.parse_url("https://foo:bar") - assert not protocol.parse_url("https://foo:") -