mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
Initial checkin.
This commit is contained in:
commit
b558997fd9
2
.coveragerc
Normal file
2
.coveragerc
Normal file
@ -0,0 +1,2 @@
|
||||
[report]
|
||||
include = *netlib*
|
9
.gitignore
vendored
Normal file
9
.gitignore
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
MANIFEST
|
||||
/build
|
||||
/dist
|
||||
/tmp
|
||||
/doc
|
||||
*.py[cdo]
|
||||
*.swp
|
||||
*.swo
|
||||
.coverage
|
2
README
Normal file
2
README
Normal file
@ -0,0 +1,2 @@
|
||||
Netlib is a collection of common utility functions, used by the pathod and
|
||||
mitmproxy projects.
|
0
netlib/__init__.py
Normal file
0
netlib/__init__.py
Normal file
160
netlib/odict.py
Normal file
160
netlib/odict.py
Normal file
@ -0,0 +1,160 @@
|
||||
import re, copy
|
||||
|
||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
"""
|
||||
There are Unicode conversion problems with re.subn. We try to smooth
|
||||
that over by casting the pattern and replacement to strings. We really
|
||||
need a better solution that is aware of the actual content ecoding.
|
||||
"""
|
||||
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
||||
|
||||
|
||||
class ODict:
|
||||
"""
|
||||
A dictionary-like object for managing ordered (key, value) data.
|
||||
"""
|
||||
def __init__(self, lst=None):
|
||||
self.lst = lst or []
|
||||
|
||||
def _kconv(self, s):
|
||||
return s
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.lst == other.lst
|
||||
|
||||
def __getitem__(self, k):
|
||||
"""
|
||||
Returns a list of values matching key.
|
||||
"""
|
||||
ret = []
|
||||
k = self._kconv(k)
|
||||
for i in self.lst:
|
||||
if self._kconv(i[0]) == k:
|
||||
ret.append(i[1])
|
||||
return ret
|
||||
|
||||
def _filter_lst(self, k, lst):
|
||||
k = self._kconv(k)
|
||||
new = []
|
||||
for i in lst:
|
||||
if self._kconv(i[0]) != k:
|
||||
new.append(i)
|
||||
return new
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Total number of (key, value) pairs.
|
||||
"""
|
||||
return len(self.lst)
|
||||
|
||||
def __setitem__(self, k, valuelist):
|
||||
"""
|
||||
Sets the values for key k. If there are existing values for this
|
||||
key, they are cleared.
|
||||
"""
|
||||
if isinstance(valuelist, basestring):
|
||||
raise ValueError("ODict valuelist should be lists.")
|
||||
new = self._filter_lst(k, self.lst)
|
||||
for i in valuelist:
|
||||
new.append([k, i])
|
||||
self.lst = new
|
||||
|
||||
def __delitem__(self, k):
|
||||
"""
|
||||
Delete all items matching k.
|
||||
"""
|
||||
self.lst = self._filter_lst(k, self.lst)
|
||||
|
||||
def __contains__(self, k):
|
||||
for i in self.lst:
|
||||
if self._kconv(i[0]) == self._kconv(k):
|
||||
return True
|
||||
return False
|
||||
|
||||
def add(self, key, value):
|
||||
self.lst.append([key, str(value)])
|
||||
|
||||
def get(self, k, d=None):
|
||||
if k in self:
|
||||
return self[k]
|
||||
else:
|
||||
return d
|
||||
|
||||
def items(self):
|
||||
return self.lst[:]
|
||||
|
||||
def _get_state(self):
|
||||
return [tuple(i) for i in self.lst]
|
||||
|
||||
@classmethod
|
||||
def _from_state(klass, state):
|
||||
return klass([list(i) for i in state])
|
||||
|
||||
def copy(self):
|
||||
"""
|
||||
Returns a copy of this object.
|
||||
"""
|
||||
lst = copy.deepcopy(self.lst)
|
||||
return self.__class__(lst)
|
||||
|
||||
def __repr__(self):
|
||||
elements = []
|
||||
for itm in self.lst:
|
||||
elements.append(itm[0] + ": " + itm[1])
|
||||
elements.append("")
|
||||
return "\r\n".join(elements)
|
||||
|
||||
def in_any(self, key, value, caseless=False):
|
||||
"""
|
||||
Do any of the values matching key contain value?
|
||||
|
||||
If caseless is true, value comparison is case-insensitive.
|
||||
"""
|
||||
if caseless:
|
||||
value = value.lower()
|
||||
for i in self[key]:
|
||||
if caseless:
|
||||
i = i.lower()
|
||||
if value in i:
|
||||
return True
|
||||
return False
|
||||
|
||||
def match_re(self, expr):
|
||||
"""
|
||||
Match the regular expression against each (key, value) pair. For
|
||||
each pair a string of the following format is matched against:
|
||||
|
||||
"key: value"
|
||||
"""
|
||||
for k, v in self.lst:
|
||||
s = "%s: %s"%(k, v)
|
||||
if re.search(expr, s):
|
||||
return True
|
||||
return False
|
||||
|
||||
def replace(self, pattern, repl, *args, **kwargs):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in both keys and
|
||||
values. Encoded content will be decoded before replacement, and
|
||||
re-encoded afterwards.
|
||||
|
||||
Returns the number of replacements made.
|
||||
"""
|
||||
nlst, count = [], 0
|
||||
for i in self.lst:
|
||||
k, c = safe_subn(pattern, repl, i[0], *args, **kwargs)
|
||||
count += c
|
||||
v, c = safe_subn(pattern, repl, i[1], *args, **kwargs)
|
||||
count += c
|
||||
nlst.append([k, v])
|
||||
self.lst = nlst
|
||||
return count
|
||||
|
||||
|
||||
class ODictCaseless(ODict):
|
||||
"""
|
||||
A variant of ODict with "caseless" keys. This version _preserves_ key
|
||||
case, but does not consider case when setting or getting items.
|
||||
"""
|
||||
def _kconv(self, s):
|
||||
return s.lower()
|
218
netlib/protocol.py
Normal file
218
netlib/protocol.py
Normal file
@ -0,0 +1,218 @@
|
||||
import string, urlparse
|
||||
|
||||
class ProtocolError(Exception):
|
||||
def __init__(self, code, msg):
|
||||
self.code, self.msg = code, msg
|
||||
|
||||
def __str__(self):
|
||||
return "ProtocolError(%s, %s)"%(self.code, self.msg)
|
||||
|
||||
|
||||
def parse_url(url):
|
||||
"""
|
||||
Returns a (scheme, host, port, path) tuple, or None on error.
|
||||
"""
|
||||
scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
|
||||
if not scheme:
|
||||
return None
|
||||
if ':' in netloc:
|
||||
host, port = string.rsplit(netloc, ':', maxsplit=1)
|
||||
try:
|
||||
port = int(port)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
host = netloc
|
||||
if scheme == "https":
|
||||
port = 443
|
||||
else:
|
||||
port = 80
|
||||
path = urlparse.urlunparse(('', '', path, params, query, fragment))
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
return scheme, host, port, path
|
||||
|
||||
|
||||
def read_headers(fp):
|
||||
"""
|
||||
Read a set of headers from a file pointer. Stop once a blank line
|
||||
is reached. Return a ODictCaseless object.
|
||||
"""
|
||||
ret = []
|
||||
name = ''
|
||||
while 1:
|
||||
line = fp.readline()
|
||||
if not line or line == '\r\n' or line == '\n':
|
||||
break
|
||||
if line[0] in ' \t':
|
||||
# continued header
|
||||
ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
|
||||
else:
|
||||
i = line.find(':')
|
||||
# We're being liberal in what we accept, here.
|
||||
if i > 0:
|
||||
name = line[:i]
|
||||
value = line[i+1:].strip()
|
||||
ret.append([name, value])
|
||||
return ret
|
||||
|
||||
|
||||
def read_chunked(fp, limit):
|
||||
content = ""
|
||||
total = 0
|
||||
while 1:
|
||||
line = fp.readline(128)
|
||||
if line == "":
|
||||
raise IOError("Connection closed")
|
||||
if line == '\r\n' or line == '\n':
|
||||
continue
|
||||
try:
|
||||
length = int(line,16)
|
||||
except ValueError:
|
||||
# FIXME: Not strictly correct - this could be from the server, in which
|
||||
# case we should send a 502.
|
||||
raise ProtocolError(400, "Invalid chunked encoding length: %s"%line)
|
||||
if not length:
|
||||
break
|
||||
total += length
|
||||
if limit is not None and total > limit:
|
||||
msg = "HTTP Body too large."\
|
||||
" Limit is %s, chunked content length was at least %s"%(limit, total)
|
||||
raise ProtocolError(509, msg)
|
||||
content += fp.read(length)
|
||||
line = fp.readline(5)
|
||||
if line != '\r\n':
|
||||
raise IOError("Malformed chunked body")
|
||||
while 1:
|
||||
line = fp.readline()
|
||||
if line == "":
|
||||
raise IOError("Connection closed")
|
||||
if line == '\r\n' or line == '\n':
|
||||
break
|
||||
return content
|
||||
|
||||
|
||||
def has_chunked_encoding(headers):
|
||||
for i in headers["transfer-encoding"]:
|
||||
for j in i.split(","):
|
||||
if j.lower() == "chunked":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def read_http_body(rfile, headers, all, limit):
|
||||
if has_chunked_encoding(headers):
|
||||
content = read_chunked(rfile, limit)
|
||||
elif "content-length" in headers:
|
||||
try:
|
||||
l = int(headers["content-length"][0])
|
||||
except ValueError:
|
||||
# FIXME: Not strictly correct - this could be from the server, in which
|
||||
# case we should send a 502.
|
||||
raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"])
|
||||
if limit is not None and l > limit:
|
||||
raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))
|
||||
content = rfile.read(l)
|
||||
elif all:
|
||||
content = rfile.read(limit if limit else None)
|
||||
else:
|
||||
content = ""
|
||||
return content
|
||||
|
||||
|
||||
def parse_http_protocol(s):
|
||||
if not s.startswith("HTTP/"):
|
||||
return None
|
||||
major, minor = s.split('/')[1].split('.')
|
||||
major = int(major)
|
||||
minor = int(minor)
|
||||
return major, minor
|
||||
|
||||
|
||||
def parse_init_connect(line):
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
if method != 'CONNECT':
|
||||
return None
|
||||
try:
|
||||
host, port = url.split(":")
|
||||
except ValueError:
|
||||
return None
|
||||
port = int(port)
|
||||
httpversion = parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
return host, port, httpversion
|
||||
|
||||
|
||||
def parse_init_proxy(line):
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
parts = parse_url(url)
|
||||
if not parts:
|
||||
return None
|
||||
scheme, host, port, path = parts
|
||||
httpversion = parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
return method, scheme, host, port, path, httpversion
|
||||
|
||||
|
||||
def parse_init_http(line):
|
||||
"""
|
||||
Returns (method, url, httpversion)
|
||||
"""
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
if not (url.startswith("/") or url == "*"):
|
||||
return None
|
||||
httpversion = parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
|
||||
def request_connection_close(httpversion, headers):
|
||||
"""
|
||||
Checks the request to see if the client connection should be closed.
|
||||
"""
|
||||
if "connection" in headers:
|
||||
for value in ",".join(headers['connection']).split(","):
|
||||
value = value.strip()
|
||||
if value == "close":
|
||||
return True
|
||||
elif value == "keep-alive":
|
||||
return False
|
||||
# HTTP 1.1 connections are assumed to be persistent
|
||||
if httpversion == (1, 1):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def response_connection_close(httpversion, headers):
|
||||
"""
|
||||
Checks the response to see if the client connection should be closed.
|
||||
"""
|
||||
if request_connection_close(httpversion, headers):
|
||||
return True
|
||||
elif not has_chunked_encoding(headers) and "content-length" in headers:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def read_http_body_request(rfile, wfile, headers, httpversion, limit):
|
||||
if "expect" in headers:
|
||||
# FIXME: Should be forwarded upstream
|
||||
expect = ",".join(headers['expect'])
|
||||
if expect == "100-continue" and httpversion >= (1, 1):
|
||||
wfile.write('HTTP/1.1 100 Continue\r\n')
|
||||
wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
|
||||
wfile.write('\r\n')
|
||||
del headers['expect']
|
||||
return read_http_body(rfile, headers, False, limit)
|
182
netlib/tcp.py
Normal file
182
netlib/tcp.py
Normal file
@ -0,0 +1,182 @@
|
||||
import select, socket, threading, traceback, sys
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
class NetLibError(Exception): pass
|
||||
|
||||
|
||||
class FileLike:
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.o, attr)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def read(self, length):
|
||||
result = ''
|
||||
while len(result) < length:
|
||||
try:
|
||||
data = self.o.read(length)
|
||||
except SSL.ZeroReturnError:
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
result += data
|
||||
return result
|
||||
|
||||
def write(self, v):
|
||||
self.o.sendall(v)
|
||||
|
||||
def readline(self, size = None):
|
||||
result = ''
|
||||
bytes_read = 0
|
||||
while True:
|
||||
if size is not None and bytes_read >= size:
|
||||
break
|
||||
ch = self.read(1)
|
||||
bytes_read += 1
|
||||
if not ch:
|
||||
break
|
||||
else:
|
||||
result += ch
|
||||
if ch == '\n':
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
class TCPClient:
|
||||
def __init__(self, ssl, host, port, clientcert):
|
||||
self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
|
||||
self.connection, self.rfile, self.wfile = None, None, None
|
||||
self.cert = None
|
||||
self.connect()
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
addr = socket.gethostbyname(self.host)
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
if self.ssl:
|
||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if self.clientcert:
|
||||
context.use_certificate_file(self.clientcert)
|
||||
server = SSL.Connection(context, server)
|
||||
server.connect((addr, self.port))
|
||||
if self.ssl:
|
||||
self.cert = server.get_peer_certificate()
|
||||
self.rfile, self.wfile = FileLike(server), FileLike(server)
|
||||
else:
|
||||
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
|
||||
except socket.error, err:
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
||||
self.connection = server
|
||||
|
||||
|
||||
class BaseHandler:
|
||||
rbufsize = -1
|
||||
wbufsize = 0
|
||||
def __init__(self, connection, client_address, server):
|
||||
self.connection = connection
|
||||
self.rfile = self.connection.makefile('rb', self.rbufsize)
|
||||
self.wfile = self.connection.makefile('wb', self.wbufsize)
|
||||
|
||||
self.client_address = client_address
|
||||
self.server = server
|
||||
self.handle()
|
||||
self.finish()
|
||||
|
||||
def convert_to_ssl(self, cert, key):
|
||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||
ctx.use_privatekey_file(key)
|
||||
ctx.use_certificate_file(cert)
|
||||
self.connection = SSL.Connection(ctx, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
self.rfile = FileLike(self.connection)
|
||||
self.wfile = FileLike(self.connection)
|
||||
|
||||
def finish(self):
|
||||
try:
|
||||
if not getattr(self.wfile, "closed", False):
|
||||
self.wfile.flush()
|
||||
self.connection.close()
|
||||
self.wfile.close()
|
||||
self.rfile.close()
|
||||
except IOError: # pragma: no cover
|
||||
pass
|
||||
|
||||
def handle(self): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TCPServer:
|
||||
request_queue_size = 20
|
||||
def __init__(self, server_address):
|
||||
self.server_address = server_address
|
||||
self.__is_shut_down = threading.Event()
|
||||
self.__shutdown_request = False
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.bind(self.server_address)
|
||||
self.server_address = self.socket.getsockname()
|
||||
self.socket.listen(self.request_queue_size)
|
||||
self.port = self.socket.getsockname()[1]
|
||||
|
||||
def request_thread(self, request, client_address):
|
||||
try:
|
||||
self.handle_connection(request, client_address)
|
||||
request.close()
|
||||
except:
|
||||
self.handle_error(request, client_address)
|
||||
request.close()
|
||||
|
||||
def serve_forever(self, poll_interval=0.5):
|
||||
self.__is_shut_down.clear()
|
||||
try:
|
||||
while not self.__shutdown_request:
|
||||
r, w, e = select.select([self.socket], [], [], poll_interval)
|
||||
if self.socket in r:
|
||||
try:
|
||||
request, client_address = self.socket.accept()
|
||||
except socket.error:
|
||||
return
|
||||
try:
|
||||
t = threading.Thread(
|
||||
target = self.request_thread,
|
||||
args = (request, client_address)
|
||||
)
|
||||
t.setDaemon(1)
|
||||
t.start()
|
||||
except:
|
||||
self.handle_error(request, client_address)
|
||||
request.close()
|
||||
finally:
|
||||
self.__shutdown_request = False
|
||||
self.__is_shut_down.set()
|
||||
|
||||
def shutdown(self):
|
||||
self.__shutdown_request = True
|
||||
self.__is_shut_down.wait()
|
||||
self.handle_shutdown()
|
||||
|
||||
def handle_error(self, request, client_address, fp=sys.stderr):
|
||||
"""
|
||||
Called when handle_connection raises an exception.
|
||||
"""
|
||||
print >> fp, '-'*40
|
||||
print >> fp, "Error processing of request from %s:%s"%client_address
|
||||
print >> fp, traceback.format_exc()
|
||||
print >> fp, '-'*40
|
||||
|
||||
def handle_connection(self, request, client_address): # pragma: no cover
|
||||
"""
|
||||
Called after client connection.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def handle_shutdown(self):
|
||||
"""
|
||||
Called after server shutdown.
|
||||
"""
|
||||
pass
|
113
test/test_odict.py
Normal file
113
test/test_odict.py
Normal file
@ -0,0 +1,113 @@
|
||||
from netlib import odict
|
||||
import tutils
|
||||
|
||||
|
||||
class TestODict:
|
||||
def setUp(self):
|
||||
self.od = odict.ODict()
|
||||
|
||||
def test_str_err(self):
|
||||
h = odict.ODict()
|
||||
tutils.raises(ValueError, h.__setitem__, "key", "foo")
|
||||
|
||||
def test_dictToHeader1(self):
|
||||
self.od.add("one", "uno")
|
||||
self.od.add("two", "due")
|
||||
self.od.add("two", "tre")
|
||||
expected = [
|
||||
"one: uno\r\n",
|
||||
"two: due\r\n",
|
||||
"two: tre\r\n",
|
||||
"\r\n"
|
||||
]
|
||||
out = repr(self.od)
|
||||
for i in expected:
|
||||
assert out.find(i) >= 0
|
||||
|
||||
def test_dictToHeader2(self):
|
||||
self.od["one"] = ["uno"]
|
||||
expected1 = "one: uno\r\n"
|
||||
expected2 = "\r\n"
|
||||
out = repr(self.od)
|
||||
assert out.find(expected1) >= 0
|
||||
assert out.find(expected2) >= 0
|
||||
|
||||
def test_match_re(self):
|
||||
h = odict.ODict()
|
||||
h.add("one", "uno")
|
||||
h.add("two", "due")
|
||||
h.add("two", "tre")
|
||||
assert h.match_re("uno")
|
||||
assert h.match_re("two: due")
|
||||
assert not h.match_re("nonono")
|
||||
|
||||
def test_getset_state(self):
|
||||
self.od.add("foo", 1)
|
||||
self.od.add("foo", 2)
|
||||
self.od.add("bar", 3)
|
||||
state = self.od._get_state()
|
||||
nd = odict.ODict._from_state(state)
|
||||
assert nd == self.od
|
||||
|
||||
def test_in_any(self):
|
||||
self.od["one"] = ["atwoa", "athreea"]
|
||||
assert self.od.in_any("one", "two")
|
||||
assert self.od.in_any("one", "three")
|
||||
assert not self.od.in_any("one", "four")
|
||||
assert not self.od.in_any("nonexistent", "foo")
|
||||
assert not self.od.in_any("one", "TWO")
|
||||
assert self.od.in_any("one", "TWO", True)
|
||||
|
||||
def test_copy(self):
|
||||
self.od.add("foo", 1)
|
||||
self.od.add("foo", 2)
|
||||
self.od.add("bar", 3)
|
||||
assert self.od == self.od.copy()
|
||||
|
||||
def test_del(self):
|
||||
self.od.add("foo", 1)
|
||||
self.od.add("Foo", 2)
|
||||
self.od.add("bar", 3)
|
||||
del self.od["foo"]
|
||||
assert len(self.od.lst) == 2
|
||||
|
||||
def test_replace(self):
|
||||
self.od.add("one", "two")
|
||||
self.od.add("two", "one")
|
||||
assert self.od.replace("one", "vun") == 2
|
||||
assert self.od.lst == [
|
||||
["vun", "two"],
|
||||
["two", "vun"],
|
||||
]
|
||||
|
||||
def test_get(self):
|
||||
self.od.add("one", "two")
|
||||
assert self.od.get("one") == ["two"]
|
||||
assert self.od.get("two") == None
|
||||
|
||||
|
||||
class TestODictCaseless:
|
||||
def setUp(self):
|
||||
self.od = odict.ODictCaseless()
|
||||
|
||||
def test_override(self):
|
||||
o = odict.ODictCaseless()
|
||||
o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8')
|
||||
o["T"] = ["foo"]
|
||||
assert o["T"] == ["foo"]
|
||||
|
||||
def test_case_preservation(self):
|
||||
self.od["Foo"] = ["1"]
|
||||
assert "foo" in self.od
|
||||
assert self.od.items()[0][0] == "Foo"
|
||||
assert self.od.get("foo") == ["1"]
|
||||
assert self.od.get("foo", [""]) == ["1"]
|
||||
assert self.od.get("Foo", [""]) == ["1"]
|
||||
assert self.od.get("xx", "yy") == "yy"
|
||||
|
||||
def test_del(self):
|
||||
self.od.add("foo", 1)
|
||||
self.od.add("Foo", 2)
|
||||
self.od.add("bar", 3)
|
||||
del self.od["foo"]
|
||||
assert len(self.od) == 1
|
163
test/test_protocol.py
Normal file
163
test/test_protocol.py
Normal file
@ -0,0 +1,163 @@
|
||||
import cStringIO, textwrap
|
||||
from netlib import protocol, odict
|
||||
import tutils
|
||||
|
||||
def test_has_chunked_encoding():
|
||||
h = odict.ODictCaseless()
|
||||
assert not protocol.has_chunked_encoding(h)
|
||||
h["transfer-encoding"] = ["chunked"]
|
||||
assert protocol.has_chunked_encoding(h)
|
||||
|
||||
|
||||
def test_read_chunked():
|
||||
s = cStringIO.StringIO("1\r\na\r\n0\r\n")
|
||||
tutils.raises(IOError, protocol.read_chunked, s, None)
|
||||
|
||||
s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
|
||||
assert protocol.read_chunked(s, None) == "a"
|
||||
|
||||
s = cStringIO.StringIO("\r\n")
|
||||
tutils.raises(IOError, protocol.read_chunked, s, None)
|
||||
|
||||
s = cStringIO.StringIO("1\r\nfoo")
|
||||
tutils.raises(IOError, protocol.read_chunked, s, None)
|
||||
|
||||
s = cStringIO.StringIO("foo\r\nfoo")
|
||||
tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None)
|
||||
|
||||
|
||||
def test_request_connection_close():
|
||||
h = odict.ODictCaseless()
|
||||
assert protocol.request_connection_close((1, 0), h)
|
||||
assert not protocol.request_connection_close((1, 1), h)
|
||||
|
||||
h["connection"] = ["keep-alive"]
|
||||
assert not protocol.request_connection_close((1, 1), h)
|
||||
|
||||
|
||||
def test_read_http_body():
|
||||
h = odict.ODict()
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert protocol.read_http_body(s, h, False, None) == ""
|
||||
|
||||
h["content-length"] = ["foo"]
|
||||
s = cStringIO.StringIO("testing")
|
||||
tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None)
|
||||
|
||||
h["content-length"] = [5]
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert len(protocol.read_http_body(s, h, False, None)) == 5
|
||||
s = cStringIO.StringIO("testing")
|
||||
tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4)
|
||||
|
||||
h = odict.ODict()
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert len(protocol.read_http_body(s, h, True, 4)) == 4
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert len(protocol.read_http_body(s, h, True, 100)) == 7
|
||||
|
||||
def test_parse_http_protocol():
|
||||
assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
|
||||
assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
|
||||
assert not protocol.parse_http_protocol("foo/0.0")
|
||||
|
||||
|
||||
def test_parse_init_connect():
|
||||
assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("bogus")
|
||||
assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
|
||||
|
||||
|
||||
def test_prase_init_proxy():
|
||||
u = "GET http://foo.com:8888/test HTTP/1.1"
|
||||
m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
|
||||
assert m == "GET"
|
||||
assert s == "http"
|
||||
assert h == "foo.com"
|
||||
assert po == 8888
|
||||
assert pa == "/test"
|
||||
assert httpversion == (1, 1)
|
||||
|
||||
assert not protocol.parse_init_proxy("invalid")
|
||||
assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
|
||||
assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
|
||||
|
||||
|
||||
def test_parse_init_http():
|
||||
u = "GET /test HTTP/1.1"
|
||||
m, u, httpversion= protocol.parse_init_http(u)
|
||||
assert m == "GET"
|
||||
assert u == "/test"
|
||||
assert httpversion == (1, 1)
|
||||
|
||||
assert not protocol.parse_init_http("invalid")
|
||||
assert not protocol.parse_init_http("GET invalid HTTP/1.1")
|
||||
assert not protocol.parse_init_http("GET /test foo/1.1")
|
||||
|
||||
|
||||
class TestReadHeaders:
|
||||
def test_read_simple(self):
|
||||
data = """
|
||||
Header: one
|
||||
Header2: two
|
||||
\r\n
|
||||
"""
|
||||
data = textwrap.dedent(data)
|
||||
data = data.strip()
|
||||
s = cStringIO.StringIO(data)
|
||||
h = protocol.read_headers(s)
|
||||
assert h == [["Header", "one"], ["Header2", "two"]]
|
||||
|
||||
def test_read_multi(self):
|
||||
data = """
|
||||
Header: one
|
||||
Header: two
|
||||
\r\n
|
||||
"""
|
||||
data = textwrap.dedent(data)
|
||||
data = data.strip()
|
||||
s = cStringIO.StringIO(data)
|
||||
h = protocol.read_headers(s)
|
||||
assert h == [["Header", "one"], ["Header", "two"]]
|
||||
|
||||
def test_read_continued(self):
|
||||
data = """
|
||||
Header: one
|
||||
\ttwo
|
||||
Header2: three
|
||||
\r\n
|
||||
"""
|
||||
data = textwrap.dedent(data)
|
||||
data = data.strip()
|
||||
s = cStringIO.StringIO(data)
|
||||
h = protocol.read_headers(s)
|
||||
assert h == [["Header", "one\r\n two"], ["Header2", "three"]]
|
||||
|
||||
|
||||
def test_parse_url():
|
||||
assert not protocol.parse_url("")
|
||||
|
||||
u = "http://foo.com:8888/test"
|
||||
s, h, po, pa = protocol.parse_url(u)
|
||||
assert s == "http"
|
||||
assert h == "foo.com"
|
||||
assert po == 8888
|
||||
assert pa == "/test"
|
||||
|
||||
s, h, po, pa = protocol.parse_url("http://foo/bar")
|
||||
assert s == "http"
|
||||
assert h == "foo"
|
||||
assert po == 80
|
||||
assert pa == "/bar"
|
||||
|
||||
s, h, po, pa = protocol.parse_url("http://foo")
|
||||
assert pa == "/"
|
||||
|
||||
s, h, po, pa = protocol.parse_url("https://foo")
|
||||
assert po == 443
|
||||
|
||||
assert not protocol.parse_url("https://foo:bar")
|
||||
assert not protocol.parse_url("https://foo:")
|
||||
|
93
test/test_tcp.py
Normal file
93
test/test_tcp.py
Normal file
@ -0,0 +1,93 @@
|
||||
import cStringIO, threading, Queue
|
||||
from netlib import tcp
|
||||
import tutils
|
||||
|
||||
class ServerThread(threading.Thread):
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
self.server.serve_forever()
|
||||
|
||||
def shutdown(self):
|
||||
self.server.shutdown()
|
||||
|
||||
|
||||
class ServerTestBase:
|
||||
@classmethod
|
||||
def setupAll(cls):
|
||||
cls.server = ServerThread(cls.makeserver())
|
||||
cls.server.start()
|
||||
|
||||
@classmethod
|
||||
def teardownAll(cls):
|
||||
cls.server.shutdown()
|
||||
|
||||
|
||||
class THandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
v = self.rfile.readline()
|
||||
if v.startswith("echo"):
|
||||
self.wfile.write(v)
|
||||
elif v.startswith("error"):
|
||||
raise ValueError("Testing an error.")
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class TServer(tcp.TCPServer):
|
||||
def __init__(self, addr, q):
|
||||
tcp.TCPServer.__init__(self, addr)
|
||||
self.q = q
|
||||
|
||||
def handle_connection(self, request, client_address):
|
||||
THandler(request, client_address, self)
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
s = cStringIO.StringIO()
|
||||
tcp.TCPServer.handle_error(self, request, client_address, s)
|
||||
self.q.put(s.getvalue())
|
||||
|
||||
|
||||
class TestServer(ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
cls.q = Queue.Queue()
|
||||
s = TServer(("127.0.0.1", 0), cls.q)
|
||||
cls.port = s.port
|
||||
return s
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient(False, "127.0.0.1", self.port, None)
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
def test_error(self):
|
||||
testval = "error!\n"
|
||||
c = tcp.TCPClient(False, "127.0.0.1", self.port, None)
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert "Testing an error" in self.q.get()
|
||||
|
||||
|
||||
class TestTCPClient:
|
||||
def test_conerr(self):
|
||||
tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None)
|
||||
|
||||
|
||||
class TestFileLike:
|
||||
def test_wrap(self):
|
||||
s = cStringIO.StringIO("foobar\nfoobar")
|
||||
s = tcp.FileLike(s)
|
||||
s.flush()
|
||||
assert s.readline() == "foobar\n"
|
||||
assert s.readline() == "foobar"
|
||||
# Test __getattr__
|
||||
assert s.isatty
|
||||
|
||||
def test_limit(self):
|
||||
s = cStringIO.StringIO("foobar\nfoobar")
|
||||
s = tcp.FileLike(s)
|
||||
assert s.readline(3) == "foo"
|
56
test/tutils.py
Normal file
56
test/tutils.py
Normal file
@ -0,0 +1,56 @@
|
||||
import tempfile, os, shutil
|
||||
from contextlib import contextmanager
|
||||
from libpathod import utils
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tmpdir(*args, **kwargs):
|
||||
orig_workdir = os.getcwd()
|
||||
temp_workdir = tempfile.mkdtemp(*args, **kwargs)
|
||||
os.chdir(temp_workdir)
|
||||
|
||||
yield temp_workdir
|
||||
|
||||
os.chdir(orig_workdir)
|
||||
shutil.rmtree(temp_workdir)
|
||||
|
||||
|
||||
def raises(exc, obj, *args, **kwargs):
|
||||
"""
|
||||
Assert that a callable raises a specified exception.
|
||||
|
||||
:exc An exception class or a string. If a class, assert that an
|
||||
exception of this type is raised. If a string, assert that the string
|
||||
occurs in the string representation of the exception, based on a
|
||||
case-insenstivie match.
|
||||
|
||||
:obj A callable object.
|
||||
|
||||
:args Arguments to be passsed to the callable.
|
||||
|
||||
:kwargs Arguments to be passed to the callable.
|
||||
"""
|
||||
try:
|
||||
apply(obj, args, kwargs)
|
||||
except Exception, v:
|
||||
if isinstance(exc, basestring):
|
||||
if exc.lower() in str(v).lower():
|
||||
return
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Expected %s, but caught %s"%(
|
||||
repr(str(exc)), v
|
||||
)
|
||||
)
|
||||
else:
|
||||
if isinstance(v, exc):
|
||||
return
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Expected %s, but caught %s %s"%(
|
||||
exc.__name__, v.__class__.__name__, str(v)
|
||||
)
|
||||
)
|
||||
raise AssertionError("No exception raised.")
|
||||
|
||||
test_data = utils.Data(__name__)
|
Loading…
Reference in New Issue
Block a user