mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 14:58:38 +00:00
Extract protocol and tcp server implementations into netlib.
This commit is contained in:
parent
7b9756f48e
commit
1b1ccab8b7
3
.gitignore
vendored
3
.gitignore
vendored
@ -8,6 +8,5 @@ MANIFEST
|
||||
*.swo
|
||||
mitmproxyc
|
||||
mitmdumpc
|
||||
mitmplaybackc
|
||||
mitmrecordc
|
||||
netlib
|
||||
.coverage
|
||||
|
@ -21,11 +21,15 @@ import hashlib, Cookie, cookielib, copy, re, urlparse
|
||||
import time
|
||||
import tnetstring, filt, script, utils, encoding, proxy
|
||||
from email.utils import parsedate_tz, formatdate, mktime_tz
|
||||
import controller, version, certutils, protocol
|
||||
from netlib import odict, protocol
|
||||
import controller, version, certutils
|
||||
|
||||
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||
CONTENT_MISSING = 0
|
||||
|
||||
ODict = odict.ODict
|
||||
ODictCaseless = odict.ODictCaseless
|
||||
|
||||
|
||||
class ReplaceHooks:
|
||||
def __init__(self):
|
||||
@ -117,157 +121,6 @@ class ScriptContext:
|
||||
self._master.replay_request(f)
|
||||
|
||||
|
||||
class ODict:
|
||||
"""
|
||||
A dictionary-like object for managing ordered (key, value) data.
|
||||
"""
|
||||
def __init__(self, lst=None):
|
||||
self.lst = lst or []
|
||||
|
||||
def _kconv(self, s):
|
||||
return s
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.lst == other.lst
|
||||
|
||||
def __getitem__(self, k):
|
||||
"""
|
||||
Returns a list of values matching key.
|
||||
"""
|
||||
ret = []
|
||||
k = self._kconv(k)
|
||||
for i in self.lst:
|
||||
if self._kconv(i[0]) == k:
|
||||
ret.append(i[1])
|
||||
return ret
|
||||
|
||||
def _filter_lst(self, k, lst):
|
||||
k = self._kconv(k)
|
||||
new = []
|
||||
for i in lst:
|
||||
if self._kconv(i[0]) != k:
|
||||
new.append(i)
|
||||
return new
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Total number of (key, value) pairs.
|
||||
"""
|
||||
return len(self.lst)
|
||||
|
||||
def __setitem__(self, k, valuelist):
|
||||
"""
|
||||
Sets the values for key k. If there are existing values for this
|
||||
key, they are cleared.
|
||||
"""
|
||||
if isinstance(valuelist, basestring):
|
||||
raise ValueError("ODict valuelist should be lists.")
|
||||
new = self._filter_lst(k, self.lst)
|
||||
for i in valuelist:
|
||||
new.append([k, i])
|
||||
self.lst = new
|
||||
|
||||
def __delitem__(self, k):
|
||||
"""
|
||||
Delete all items matching k.
|
||||
"""
|
||||
self.lst = self._filter_lst(k, self.lst)
|
||||
|
||||
def __contains__(self, k):
|
||||
for i in self.lst:
|
||||
if self._kconv(i[0]) == self._kconv(k):
|
||||
return True
|
||||
return False
|
||||
|
||||
def add(self, key, value):
|
||||
self.lst.append([key, str(value)])
|
||||
|
||||
def get(self, k, d=None):
|
||||
if k in self:
|
||||
return self[k]
|
||||
else:
|
||||
return d
|
||||
|
||||
def items(self):
|
||||
return self.lst[:]
|
||||
|
||||
def _get_state(self):
|
||||
return [tuple(i) for i in self.lst]
|
||||
|
||||
@classmethod
|
||||
def _from_state(klass, state):
|
||||
return klass([list(i) for i in state])
|
||||
|
||||
def copy(self):
|
||||
"""
|
||||
Returns a copy of this object.
|
||||
"""
|
||||
lst = copy.deepcopy(self.lst)
|
||||
return self.__class__(lst)
|
||||
|
||||
def __repr__(self):
|
||||
elements = []
|
||||
for itm in self.lst:
|
||||
elements.append(itm[0] + ": " + itm[1])
|
||||
elements.append("")
|
||||
return "\r\n".join(elements)
|
||||
|
||||
def in_any(self, key, value, caseless=False):
|
||||
"""
|
||||
Do any of the values matching key contain value?
|
||||
|
||||
If caseless is true, value comparison is case-insensitive.
|
||||
"""
|
||||
if caseless:
|
||||
value = value.lower()
|
||||
for i in self[key]:
|
||||
if caseless:
|
||||
i = i.lower()
|
||||
if value in i:
|
||||
return True
|
||||
return False
|
||||
|
||||
def match_re(self, expr):
|
||||
"""
|
||||
Match the regular expression against each (key, value) pair. For
|
||||
each pair a string of the following format is matched against:
|
||||
|
||||
"key: value"
|
||||
"""
|
||||
for k, v in self.lst:
|
||||
s = "%s: %s"%(k, v)
|
||||
if re.search(expr, s):
|
||||
return True
|
||||
return False
|
||||
|
||||
def replace(self, pattern, repl, *args, **kwargs):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in both keys and
|
||||
values. Encoded content will be decoded before replacement, and
|
||||
re-encoded afterwards.
|
||||
|
||||
Returns the number of replacements made.
|
||||
"""
|
||||
nlst, count = [], 0
|
||||
for i in self.lst:
|
||||
k, c = utils.safe_subn(pattern, repl, i[0], *args, **kwargs)
|
||||
count += c
|
||||
v, c = utils.safe_subn(pattern, repl, i[1], *args, **kwargs)
|
||||
count += c
|
||||
nlst.append([k, v])
|
||||
self.lst = nlst
|
||||
return count
|
||||
|
||||
|
||||
class ODictCaseless(ODict):
|
||||
"""
|
||||
A variant of ODict with "caseless" keys. This version _preserves_ key
|
||||
case, but does not consider case when setting or getting items.
|
||||
"""
|
||||
def _kconv(self, s):
|
||||
return s.lower()
|
||||
|
||||
|
||||
class decoded(object):
|
||||
"""
|
||||
|
||||
|
@ -1,182 +0,0 @@
|
||||
import select, socket, threading, traceback, sys
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
class NetLibError(Exception): pass
|
||||
|
||||
|
||||
class FileLike:
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.o, attr)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def read(self, length):
|
||||
result = ''
|
||||
while len(result) < length:
|
||||
try:
|
||||
data = self.o.read(length)
|
||||
except SSL.ZeroReturnError:
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
result += data
|
||||
return result
|
||||
|
||||
def write(self, v):
|
||||
self.o.sendall(v)
|
||||
|
||||
def readline(self, size = None):
|
||||
result = ''
|
||||
bytes_read = 0
|
||||
while True:
|
||||
if size is not None and bytes_read >= size:
|
||||
break
|
||||
ch = self.read(1)
|
||||
bytes_read += 1
|
||||
if not ch:
|
||||
break
|
||||
else:
|
||||
result += ch
|
||||
if ch == '\n':
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
class TCPClient:
|
||||
def __init__(self, ssl, host, port, clientcert):
|
||||
self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
|
||||
self.connection, self.rfile, self.wfile = None, None, None
|
||||
self.cert = None
|
||||
self.connect()
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
addr = socket.gethostbyname(self.host)
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
if self.ssl:
|
||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if self.clientcert:
|
||||
context.use_certificate_file(self.clientcert)
|
||||
server = SSL.Connection(context, server)
|
||||
server.connect((addr, self.port))
|
||||
if self.ssl:
|
||||
self.cert = server.get_peer_certificate()
|
||||
self.rfile, self.wfile = FileLike(server), FileLike(server)
|
||||
else:
|
||||
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
|
||||
except socket.error, err:
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
||||
self.connection = server
|
||||
|
||||
|
||||
class BaseHandler:
|
||||
rbufsize = -1
|
||||
wbufsize = 0
|
||||
def __init__(self, connection, client_address, server):
|
||||
self.connection = connection
|
||||
self.rfile = self.connection.makefile('rb', self.rbufsize)
|
||||
self.wfile = self.connection.makefile('wb', self.wbufsize)
|
||||
|
||||
self.client_address = client_address
|
||||
self.server = server
|
||||
self.handle()
|
||||
self.finish()
|
||||
|
||||
def convert_to_ssl(self, cert, key):
|
||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||
ctx.use_privatekey_file(key)
|
||||
ctx.use_certificate_file(cert)
|
||||
self.connection = SSL.Connection(ctx, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
self.rfile = FileLike(self.connection)
|
||||
self.wfile = FileLike(self.connection)
|
||||
|
||||
def finish(self):
|
||||
try:
|
||||
if not getattr(self.wfile, "closed", False):
|
||||
self.wfile.flush()
|
||||
self.connection.close()
|
||||
self.wfile.close()
|
||||
self.rfile.close()
|
||||
except IOError: # pragma: no cover
|
||||
pass
|
||||
|
||||
def handle(self): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TCPServer:
|
||||
request_queue_size = 20
|
||||
def __init__(self, server_address):
|
||||
self.server_address = server_address
|
||||
self.__is_shut_down = threading.Event()
|
||||
self.__shutdown_request = False
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.bind(self.server_address)
|
||||
self.server_address = self.socket.getsockname()
|
||||
self.socket.listen(self.request_queue_size)
|
||||
self.port = self.socket.getsockname()[1]
|
||||
|
||||
def request_thread(self, request, client_address):
|
||||
try:
|
||||
self.handle_connection(request, client_address)
|
||||
request.close()
|
||||
except:
|
||||
self.handle_error(request, client_address)
|
||||
request.close()
|
||||
|
||||
def serve_forever(self, poll_interval=0.5):
|
||||
self.__is_shut_down.clear()
|
||||
try:
|
||||
while not self.__shutdown_request:
|
||||
r, w, e = select.select([self.socket], [], [], poll_interval)
|
||||
if self.socket in r:
|
||||
try:
|
||||
request, client_address = self.socket.accept()
|
||||
except socket.error:
|
||||
return
|
||||
try:
|
||||
t = threading.Thread(
|
||||
target = self.request_thread,
|
||||
args = (request, client_address)
|
||||
)
|
||||
t.setDaemon(1)
|
||||
t.start()
|
||||
except:
|
||||
self.handle_error(request, client_address)
|
||||
request.close()
|
||||
finally:
|
||||
self.__shutdown_request = False
|
||||
self.__is_shut_down.set()
|
||||
|
||||
def shutdown(self):
|
||||
self.__shutdown_request = True
|
||||
self.__is_shut_down.wait()
|
||||
self.handle_shutdown()
|
||||
|
||||
def handle_error(self, request, client_address, fp=sys.stderr):
|
||||
"""
|
||||
Called when handle_connection raises an exception.
|
||||
"""
|
||||
print >> fp, '-'*40
|
||||
print >> fp, "Error processing of request from %s:%s"%client_address
|
||||
print >> fp, traceback.format_exc()
|
||||
print >> fp, '-'*40
|
||||
|
||||
def handle_connection(self, request, client_address): # pragma: no cover
|
||||
"""
|
||||
Called after client connection.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def handle_shutdown(self):
|
||||
"""
|
||||
Called after server shutdown.
|
||||
"""
|
||||
pass
|
@ -1,220 +0,0 @@
|
||||
import string, urlparse
|
||||
|
||||
class ProtocolError(Exception):
|
||||
def __init__(self, code, msg):
|
||||
self.code, self.msg = code, msg
|
||||
|
||||
def __str__(self):
|
||||
return "ProtocolError(%s, %s)"%(self.code, self.msg)
|
||||
|
||||
|
||||
def parse_url(url):
|
||||
"""
|
||||
Returns a (scheme, host, port, path) tuple, or None on error.
|
||||
"""
|
||||
scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
|
||||
if not scheme:
|
||||
return None
|
||||
if ':' in netloc:
|
||||
host, port = string.rsplit(netloc, ':', maxsplit=1)
|
||||
try:
|
||||
port = int(port)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
host = netloc
|
||||
if scheme == "https":
|
||||
port = 443
|
||||
else:
|
||||
port = 80
|
||||
path = urlparse.urlunparse(('', '', path, params, query, fragment))
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
return scheme, host, port, path
|
||||
|
||||
|
||||
def read_headers(fp):
|
||||
"""
|
||||
Read a set of headers from a file pointer. Stop once a blank line
|
||||
is reached. Return a ODictCaseless object.
|
||||
"""
|
||||
ret = []
|
||||
name = ''
|
||||
while 1:
|
||||
line = fp.readline()
|
||||
if not line or line == '\r\n' or line == '\n':
|
||||
break
|
||||
if line[0] in ' \t':
|
||||
# continued header
|
||||
ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
|
||||
else:
|
||||
i = line.find(':')
|
||||
# We're being liberal in what we accept, here.
|
||||
if i > 0:
|
||||
name = line[:i]
|
||||
value = line[i+1:].strip()
|
||||
ret.append([name, value])
|
||||
return ret
|
||||
|
||||
|
||||
def read_chunked(fp, limit):
|
||||
content = ""
|
||||
total = 0
|
||||
while 1:
|
||||
line = fp.readline(128)
|
||||
if line == "":
|
||||
raise IOError("Connection closed")
|
||||
if line == '\r\n' or line == '\n':
|
||||
continue
|
||||
try:
|
||||
length = int(line,16)
|
||||
except ValueError:
|
||||
# FIXME: Not strictly correct - this could be from the server, in which
|
||||
# case we should send a 502.
|
||||
raise ProtocolError(400, "Invalid chunked encoding length: %s"%line)
|
||||
if not length:
|
||||
break
|
||||
total += length
|
||||
if limit is not None and total > limit:
|
||||
msg = "HTTP Body too large."\
|
||||
" Limit is %s, chunked content length was at least %s"%(limit, total)
|
||||
raise ProtocolError(509, msg)
|
||||
content += fp.read(length)
|
||||
line = fp.readline(5)
|
||||
if line != '\r\n':
|
||||
raise IOError("Malformed chunked body")
|
||||
while 1:
|
||||
line = fp.readline()
|
||||
if line == "":
|
||||
raise IOError("Connection closed")
|
||||
if line == '\r\n' or line == '\n':
|
||||
break
|
||||
return content
|
||||
|
||||
|
||||
def has_chunked_encoding(headers):
|
||||
for i in headers["transfer-encoding"]:
|
||||
for j in i.split(","):
|
||||
if j.lower() == "chunked":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def read_http_body(rfile, headers, all, limit):
|
||||
if has_chunked_encoding(headers):
|
||||
content = read_chunked(rfile, limit)
|
||||
elif "content-length" in headers:
|
||||
try:
|
||||
l = int(headers["content-length"][0])
|
||||
except ValueError:
|
||||
# FIXME: Not strictly correct - this could be from the server, in which
|
||||
# case we should send a 502.
|
||||
raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"])
|
||||
if limit is not None and l > limit:
|
||||
raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))
|
||||
content = rfile.read(l)
|
||||
elif all:
|
||||
content = rfile.read(limit if limit else None)
|
||||
else:
|
||||
content = ""
|
||||
return content
|
||||
|
||||
|
||||
def parse_http_protocol(s):
|
||||
if not s.startswith("HTTP/"):
|
||||
return None
|
||||
major, minor = s.split('/')[1].split('.')
|
||||
major = int(major)
|
||||
minor = int(minor)
|
||||
return major, minor
|
||||
|
||||
|
||||
def parse_init_connect(line):
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
if method != 'CONNECT':
|
||||
return None
|
||||
try:
|
||||
host, port = url.split(":")
|
||||
except ValueError:
|
||||
return None
|
||||
port = int(port)
|
||||
httpversion = parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
return host, port, httpversion
|
||||
|
||||
|
||||
def parse_init_proxy(line):
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
parts = parse_url(url)
|
||||
if not parts:
|
||||
return None
|
||||
scheme, host, port, path = parts
|
||||
httpversion = parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
return method, scheme, host, port, path, httpversion
|
||||
|
||||
|
||||
def parse_init_http(line):
|
||||
"""
|
||||
Returns (method, url, httpversion)
|
||||
"""
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
if not (url.startswith("/") or url == "*"):
|
||||
return None
|
||||
httpversion = parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
|
||||
def request_connection_close(httpversion, headers):
|
||||
"""
|
||||
Checks the request to see if the client connection should be closed.
|
||||
"""
|
||||
if "connection" in headers:
|
||||
for value in ",".join(headers['connection']).split(","):
|
||||
value = value.strip()
|
||||
if value == "close":
|
||||
return True
|
||||
elif value == "keep-alive":
|
||||
return False
|
||||
# HTTP 1.1 connections are assumed to be persistent
|
||||
if httpversion == (1, 1):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def response_connection_close(httpversion, headers):
|
||||
"""
|
||||
Checks the response to see if the client connection should be closed.
|
||||
"""
|
||||
if request_connection_close(httpversion, headers):
|
||||
return True
|
||||
elif not has_chunked_encoding(headers) and "content-length" in headers:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def read_http_body_request(rfile, wfile, headers, httpversion, limit):
|
||||
if "expect" in headers:
|
||||
# FIXME: Should be forwarded upstream
|
||||
expect = ",".join(headers['expect'])
|
||||
if expect == "100-continue" and httpversion >= (1, 1):
|
||||
wfile.write('HTTP/1.1 100 Continue\r\n')
|
||||
wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
|
||||
wfile.write('\r\n')
|
||||
del headers['expect']
|
||||
return read_http_body(rfile, headers, False, limit)
|
||||
|
||||
|
@ -15,8 +15,9 @@
|
||||
import sys, os, string, socket, time
|
||||
import shutil, tempfile, threading
|
||||
import optparse, SocketServer
|
||||
import utils, flow, certutils, version, wsgi, netlib, protocol
|
||||
from OpenSSL import SSL
|
||||
from netlib import odict, tcp, protocol
|
||||
import utils, flow, certutils, version, wsgi
|
||||
|
||||
|
||||
class ProxyError(Exception):
|
||||
@ -56,18 +57,18 @@ class RequestReplayThread(threading.Thread):
|
||||
except (ProxyError, protocol.ProtocolError), v:
|
||||
err = flow.Error(self.flow.request, v.msg)
|
||||
err._send(self.masterq)
|
||||
except netlib.NetLibError, v:
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
|
||||
|
||||
class ServerConnection(netlib.TCPClient):
|
||||
class ServerConnection(tcp.TCPClient):
|
||||
def __init__(self, config, scheme, host, port):
|
||||
clientcert = None
|
||||
if config.clientcerts:
|
||||
path = os.path.join(config.clientcerts, self.host) + ".pem"
|
||||
if os.path.exists(clientcert):
|
||||
clientcert = path
|
||||
netlib.TCPClient.__init__(
|
||||
tcp.TCPClient.__init__(
|
||||
self,
|
||||
True if scheme == "https" else False,
|
||||
host,
|
||||
@ -107,7 +108,7 @@ class ServerConnection(netlib.TCPClient):
|
||||
code = int(code)
|
||||
except ValueError:
|
||||
raise ProxyError(502, "Invalid server response: %s."%line)
|
||||
headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
if code >= 100 and code <= 199:
|
||||
return self.read_response()
|
||||
if request.method == "HEAD" or code == 204 or code == 304:
|
||||
@ -125,13 +126,13 @@ class ServerConnection(netlib.TCPClient):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyHandler(netlib.BaseHandler):
|
||||
class ProxyHandler(tcp.BaseHandler):
|
||||
def __init__(self, config, connection, client_address, server, q):
|
||||
self.mqueue = q
|
||||
self.config = config
|
||||
self.server_conn = None
|
||||
self.proxy_connect_state = None
|
||||
netlib.BaseHandler.__init__(self, connection, client_address, server)
|
||||
tcp.BaseHandler.__init__(self, connection, client_address, server)
|
||||
|
||||
def handle(self):
|
||||
cc = flow.ClientConnect(self.client_address)
|
||||
@ -150,7 +151,7 @@ class ProxyHandler(netlib.BaseHandler):
|
||||
if not self.server_conn:
|
||||
try:
|
||||
self.server_conn = ServerConnection(self.config, scheme, host, port)
|
||||
except netlib.NetLibError, v:
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
|
||||
def handle_request(self, cc):
|
||||
@ -243,7 +244,7 @@ class ProxyHandler(netlib.BaseHandler):
|
||||
else:
|
||||
scheme = "http"
|
||||
method, path, httpversion = protocol.parse_init_http(line)
|
||||
headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
content = protocol.read_http_body_request(
|
||||
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
|
||||
)
|
||||
@ -251,7 +252,7 @@ class ProxyHandler(netlib.BaseHandler):
|
||||
elif self.config.reverse_proxy:
|
||||
scheme, host, port = self.config.reverse_proxy
|
||||
method, path, httpversion = protocol.parse_init_http(line)
|
||||
headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
content = protocol.read_http_body_request(
|
||||
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
|
||||
)
|
||||
@ -278,14 +279,14 @@ class ProxyHandler(netlib.BaseHandler):
|
||||
if self.proxy_connect_state:
|
||||
host, port, httpversion = self.proxy_connect_state
|
||||
method, path, httpversion = protocol.parse_init_http(line)
|
||||
headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
content = protocol.read_http_body_request(
|
||||
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
|
||||
)
|
||||
return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
|
||||
else:
|
||||
method, scheme, host, port, path, httpversion = protocol.parse_init_proxy(line)
|
||||
headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
|
||||
content = protocol.read_http_body_request(
|
||||
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
|
||||
)
|
||||
@ -317,7 +318,7 @@ class ProxyHandler(netlib.BaseHandler):
|
||||
class ProxyServerError(Exception): pass
|
||||
|
||||
|
||||
class ProxyServer(netlib.TCPServer):
|
||||
class ProxyServer(tcp.TCPServer):
|
||||
allow_reuse_address = True
|
||||
bound = True
|
||||
def __init__(self, config, port, address=''):
|
||||
@ -326,7 +327,7 @@ class ProxyServer(netlib.TCPServer):
|
||||
"""
|
||||
self.config, self.port, self.address = config, port, address
|
||||
try:
|
||||
netlib.TCPServer.__init__(self, (address, port))
|
||||
tcp.TCPServer.__init__(self, (address, port))
|
||||
except socket.error, v:
|
||||
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
|
||||
self.masterq = None
|
||||
|
@ -15,7 +15,7 @@
|
||||
import os, datetime, urlparse, string, urllib, re
|
||||
import time, functools, cgi
|
||||
import json
|
||||
import protocol
|
||||
from netlib import protocol
|
||||
|
||||
def timestamp():
|
||||
"""
|
||||
@ -294,6 +294,3 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
need a better solution that is aware of the actual content ecoding.
|
||||
"""
|
||||
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
||||
|
||||
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -92,5 +92,5 @@ setup(
|
||||
"Topic :: Internet :: Proxy Servers",
|
||||
"Topic :: Software Development :: Testing"
|
||||
],
|
||||
install_requires=['urwid>=1.0', 'pyasn1>0.1.2', 'pyopenssl>=0.12', "PIL", "lxml"],
|
||||
install_requires=["netlib", "urwid>=1.0", "pyasn1>0.1.2", "pyopenssl>=0.12", "PIL", "lxml"],
|
||||
)
|
||||
|
@ -1,93 +0,0 @@
|
||||
import cStringIO, threading, Queue
|
||||
from libmproxy import netlib
|
||||
import tutils
|
||||
|
||||
class ServerThread(threading.Thread):
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
self.server.serve_forever()
|
||||
|
||||
def shutdown(self):
|
||||
self.server.shutdown()
|
||||
|
||||
|
||||
class ServerTestBase:
|
||||
@classmethod
|
||||
def setupAll(cls):
|
||||
cls.server = ServerThread(cls.makeserver())
|
||||
cls.server.start()
|
||||
|
||||
@classmethod
|
||||
def teardownAll(cls):
|
||||
cls.server.shutdown()
|
||||
|
||||
|
||||
class THandler(netlib.BaseHandler):
|
||||
def handle(self):
|
||||
v = self.rfile.readline()
|
||||
if v.startswith("echo"):
|
||||
self.wfile.write(v)
|
||||
elif v.startswith("error"):
|
||||
raise ValueError("Testing an error.")
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class TServer(netlib.TCPServer):
|
||||
def __init__(self, addr, q):
|
||||
netlib.TCPServer.__init__(self, addr)
|
||||
self.q = q
|
||||
|
||||
def handle_connection(self, request, client_address):
|
||||
THandler(request, client_address, self)
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
s = cStringIO.StringIO()
|
||||
netlib.TCPServer.handle_error(self, request, client_address, s)
|
||||
self.q.put(s.getvalue())
|
||||
|
||||
|
||||
class TestServer(ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
cls.q = Queue.Queue()
|
||||
s = TServer(("127.0.0.1", 0), cls.q)
|
||||
cls.port = s.port
|
||||
return s
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = netlib.TCPClient(False, "127.0.0.1", self.port, None)
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
def test_error(self):
|
||||
testval = "error!\n"
|
||||
c = netlib.TCPClient(False, "127.0.0.1", self.port, None)
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert "Testing an error" in self.q.get()
|
||||
|
||||
|
||||
class TestTCPClient:
|
||||
def test_conerr(self):
|
||||
tutils.raises(netlib.NetLibError, netlib.TCPClient, False, "127.0.0.1", 0, None)
|
||||
|
||||
|
||||
class TestFileLike:
|
||||
def test_wrap(self):
|
||||
s = cStringIO.StringIO("foobar\nfoobar")
|
||||
s = netlib.FileLike(s)
|
||||
s.flush()
|
||||
assert s.readline() == "foobar\n"
|
||||
assert s.readline() == "foobar"
|
||||
# Test __getattr__
|
||||
assert s.isatty
|
||||
|
||||
def test_limit(self):
|
||||
s = cStringIO.StringIO("foobar\nfoobar")
|
||||
s = netlib.FileLike(s)
|
||||
assert s.readline(3) == "foo"
|
@ -1,163 +0,0 @@
|
||||
import cStringIO, textwrap
|
||||
from libmproxy import protocol, flow
|
||||
import tutils
|
||||
|
||||
def test_has_chunked_encoding():
|
||||
h = flow.ODictCaseless()
|
||||
assert not protocol.has_chunked_encoding(h)
|
||||
h["transfer-encoding"] = ["chunked"]
|
||||
assert protocol.has_chunked_encoding(h)
|
||||
|
||||
|
||||
def test_read_chunked():
|
||||
s = cStringIO.StringIO("1\r\na\r\n0\r\n")
|
||||
tutils.raises(IOError, protocol.read_chunked, s, None)
|
||||
|
||||
s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
|
||||
assert protocol.read_chunked(s, None) == "a"
|
||||
|
||||
s = cStringIO.StringIO("\r\n")
|
||||
tutils.raises(IOError, protocol.read_chunked, s, None)
|
||||
|
||||
s = cStringIO.StringIO("1\r\nfoo")
|
||||
tutils.raises(IOError, protocol.read_chunked, s, None)
|
||||
|
||||
s = cStringIO.StringIO("foo\r\nfoo")
|
||||
tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None)
|
||||
|
||||
|
||||
def test_request_connection_close():
|
||||
h = flow.ODictCaseless()
|
||||
assert protocol.request_connection_close((1, 0), h)
|
||||
assert not protocol.request_connection_close((1, 1), h)
|
||||
|
||||
h["connection"] = ["keep-alive"]
|
||||
assert not protocol.request_connection_close((1, 1), h)
|
||||
|
||||
|
||||
def test_read_http_body():
|
||||
h = flow.ODict()
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert protocol.read_http_body(s, h, False, None) == ""
|
||||
|
||||
h["content-length"] = ["foo"]
|
||||
s = cStringIO.StringIO("testing")
|
||||
tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None)
|
||||
|
||||
h["content-length"] = [5]
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert len(protocol.read_http_body(s, h, False, None)) == 5
|
||||
s = cStringIO.StringIO("testing")
|
||||
tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4)
|
||||
|
||||
h = flow.ODict()
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert len(protocol.read_http_body(s, h, True, 4)) == 4
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert len(protocol.read_http_body(s, h, True, 100)) == 7
|
||||
|
||||
def test_parse_http_protocol():
|
||||
assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
|
||||
assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
|
||||
assert not protocol.parse_http_protocol("foo/0.0")
|
||||
|
||||
|
||||
def test_parse_init_connect():
|
||||
assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("bogus")
|
||||
assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
|
||||
|
||||
|
||||
def test_prase_init_proxy():
|
||||
u = "GET http://foo.com:8888/test HTTP/1.1"
|
||||
m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
|
||||
assert m == "GET"
|
||||
assert s == "http"
|
||||
assert h == "foo.com"
|
||||
assert po == 8888
|
||||
assert pa == "/test"
|
||||
assert httpversion == (1, 1)
|
||||
|
||||
assert not protocol.parse_init_proxy("invalid")
|
||||
assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
|
||||
assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
|
||||
|
||||
|
||||
def test_parse_init_http():
|
||||
u = "GET /test HTTP/1.1"
|
||||
m, u, httpversion= protocol.parse_init_http(u)
|
||||
assert m == "GET"
|
||||
assert u == "/test"
|
||||
assert httpversion == (1, 1)
|
||||
|
||||
assert not protocol.parse_init_http("invalid")
|
||||
assert not protocol.parse_init_http("GET invalid HTTP/1.1")
|
||||
assert not protocol.parse_init_http("GET /test foo/1.1")
|
||||
|
||||
|
||||
class TestReadHeaders:
|
||||
def test_read_simple(self):
|
||||
data = """
|
||||
Header: one
|
||||
Header2: two
|
||||
\r\n
|
||||
"""
|
||||
data = textwrap.dedent(data)
|
||||
data = data.strip()
|
||||
s = cStringIO.StringIO(data)
|
||||
h = protocol.read_headers(s)
|
||||
assert h == [["Header", "one"], ["Header2", "two"]]
|
||||
|
||||
def test_read_multi(self):
|
||||
data = """
|
||||
Header: one
|
||||
Header: two
|
||||
\r\n
|
||||
"""
|
||||
data = textwrap.dedent(data)
|
||||
data = data.strip()
|
||||
s = cStringIO.StringIO(data)
|
||||
h = protocol.read_headers(s)
|
||||
assert h == [["Header", "one"], ["Header", "two"]]
|
||||
|
||||
def test_read_continued(self):
|
||||
data = """
|
||||
Header: one
|
||||
\ttwo
|
||||
Header2: three
|
||||
\r\n
|
||||
"""
|
||||
data = textwrap.dedent(data)
|
||||
data = data.strip()
|
||||
s = cStringIO.StringIO(data)
|
||||
h = protocol.read_headers(s)
|
||||
assert h == [["Header", "one\r\n two"], ["Header2", "three"]]
|
||||
|
||||
|
||||
def test_parse_url():
|
||||
assert not protocol.parse_url("")
|
||||
|
||||
u = "http://foo.com:8888/test"
|
||||
s, h, po, pa = protocol.parse_url(u)
|
||||
assert s == "http"
|
||||
assert h == "foo.com"
|
||||
assert po == 8888
|
||||
assert pa == "/test"
|
||||
|
||||
s, h, po, pa = protocol.parse_url("http://foo/bar")
|
||||
assert s == "http"
|
||||
assert h == "foo"
|
||||
assert po == 80
|
||||
assert pa == "/bar"
|
||||
|
||||
s, h, po, pa = protocol.parse_url("http://foo")
|
||||
assert pa == "/"
|
||||
|
||||
s, h, po, pa = protocol.parse_url("https://foo")
|
||||
assert po == 443
|
||||
|
||||
assert not protocol.parse_url("https://foo:bar")
|
||||
assert not protocol.parse_url("https://foo:")
|
||||
|
Loading…
Reference in New Issue
Block a user