Add a WSGI adapter that lets us serve a WSGI app out of mitmproxy.

This commit adds:
    - A WSGI App adapter for mitmproxy
    - An app registry in the proxy instance that lets us link WSGI apps with
    (hostname, port) combinations.
    - Fixes for a number of bugs discovered while creating this feature.
This commit is contained in:
Aldo Cortesi 2012-04-24 09:43:14 +12:00
parent 51789228be
commit c8d2b2594b
7 changed files with 276 additions and 35 deletions

41
examples/proxapp Executable file
View File

@ -0,0 +1,41 @@
#!/usr/bin/env python
import bottle
import os
from libmproxy import proxy, flow
@bottle.route('/')
def index():
    """Answer requests for the app root with a fixed greeting."""
    greeting = 'Hi!'
    return greeting
class MyMaster(flow.FlowMaster):
    """
    Flow master for this example.

    Delegates every message to flow.FlowMaster and acknowledges it whenever
    the base class hands back a truthy flow.
    """
    def run(self):
        """Run the base master loop; shut down when interrupted from the keyboard."""
        try:
            flow.FlowMaster.run(self)
        except KeyboardInterrupt:
            self.shutdown()

    def handle_request(self, r):
        """Pass the request to the base class, ack it if a flow came back, and return that flow."""
        handled = flow.FlowMaster.handle_request(self, r)
        if handled:
            r._ack()
        return handled

    def handle_response(self, r):
        """Pass the response to the base class, ack it if a flow came back, print and return the flow."""
        handled = flow.FlowMaster.handle_response(self, r)
        if handled:
            r._ack()
        # NOTE(review): indentation was lost in the view this was taken from;
        # the print is assumed to sit outside the `if` - confirm.
        print(handled)
        return handled
# Proxy configuration: only the CA certificate path is supplied here.
ca_path = os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem")
config = proxy.ProxyConfig(cacert = ca_path)
state = flow.State()

# NOTE(review): 8080 is presumably the listening port - confirm against ProxyServer.
server = proxy.ProxyServer(config, 8080)

# Link the bottle app to the ("proxapp", 80) host/port combination in the
# server's app registry.
server.apps.add(bottle.app(), "proxapp", 80)

m = MyMaster(server, state)
m.run()

View File

@ -160,7 +160,6 @@ class ODict:
""" """
if isinstance(valuelist, basestring): if isinstance(valuelist, basestring):
raise ValueError("ODict valuelist should be lists.") raise ValueError("ODict valuelist should be lists.")
k = self._kconv(k)
new = self._filter_lst(k, self.lst) new = self._filter_lst(k, self.lst)
for i in valuelist: for i in valuelist:
new.append((k, i)) new.append((k, i))
@ -174,7 +173,7 @@ class ODict:
def __contains__(self, k): def __contains__(self, k):
for i in self.lst: for i in self.lst:
if self._kconv(i[0]) == k: if self._kconv(i[0]) == self._kconv(k):
return True return True
return False return False
@ -187,6 +186,9 @@ class ODict:
else: else:
return d return d
def items(self):
    """Return a copy of the underlying list of (key, value) pairs."""
    pairs = self.lst[:]
    return pairs
def _get_state(self): def _get_state(self):
return [tuple(i) for i in self.lst] return [tuple(i) for i in self.lst]

View File

@ -21,9 +21,7 @@
import sys, os, string, socket, time import sys, os, string, socket, time
import shutil, tempfile, threading import shutil, tempfile, threading
import optparse, SocketServer, ssl import optparse, SocketServer, ssl
import utils, flow, certutils import utils, flow, certutils, version, wsgi
NAME = "mitmproxy"
class ProxyError(Exception): class ProxyError(Exception):
@ -128,6 +126,8 @@ def read_http_body(rfile, connection, headers, all, limit):
return content return content
#FIXME: Return full HTTP version specification from here. Allow non-HTTP
#protocol specs, and make it all editable.
def parse_request_line(request): def parse_request_line(request):
""" """
Parse a proxy request line. Return (method, scheme, host, port, path, minor). Parse a proxy request line. Return (method, scheme, host, port, path, minor).
@ -230,7 +230,7 @@ class ServerConnection:
self.scheme = request.scheme self.scheme = request.scheme
self.close = False self.close = False
self.cert = None self.cert = None
self.server, self.rfile, self.wfile = None, None, None self.sock, self.rfile, self.wfile = None, None, None
self.connect() self.connect()
def connect(self): def connect(self):
@ -244,7 +244,7 @@ class ServerConnection:
self.cert = server.getpeercert(True) self.cert = server.getpeercert(True)
except socket.error, err: except socket.error, err:
raise ProxyError(502, 'Error connecting to "%s": %s' % (self.host, err)) raise ProxyError(502, 'Error connecting to "%s": %s' % (self.host, err))
self.server = server self.sock = server
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
def send(self): def send(self):
@ -284,7 +284,7 @@ class ServerConnection:
try: try:
if not self.wfile.closed: if not self.wfile.closed:
self.wfile.flush() self.wfile.flush()
self.server.close() self.sock.close()
except IOError: except IOError:
pass pass
@ -305,7 +305,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
self.finish() self.finish()
def handle_request(self, cc): def handle_request(self, cc):
server, request, err = None, None, None server_conn, request, err = None, None, None
try: try:
try: try:
request = self.read_request(cc) request = self.read_request(cc)
@ -315,6 +315,11 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
cc.close = True cc.close = True
return return
cc.requestcount += 1 cc.requestcount += 1
app = self.server.apps.get(request)
if app:
app.serve(request, self.wfile)
else:
request = request._send(self.mqueue) request = request._send(self.mqueue)
if request is None: if request is None:
cc.close = True cc.close = True
@ -325,15 +330,15 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
request = False request = False
response = response._send(self.mqueue) response = response._send(self.mqueue)
else: else:
server = ServerConnection(self.config, request) server_conn = ServerConnection(self.config, request)
server.send() server_conn.send()
try: try:
response = server.read_response() response = server_conn.read_response()
except IOError, v: except IOError, v:
raise IOError, "Reading response: %s"%v raise IOError, "Reading response: %s"%v
response = response._send(self.mqueue) response = response._send(self.mqueue)
if response is None: if response is None:
server.terminate() server_conn.terminate()
if response is None: if response is None:
cc.close = True cc.close = True
return return
@ -348,8 +353,8 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
err = flow.Error(request, e.msg) err = flow.Error(request, e.msg)
err._send(self.mqueue) err._send(self.mqueue)
self.send_error(e.code, e.msg) self.send_error(e.code, e.msg)
if server: if server_conn:
server.terminate() server_conn.terminate()
def find_cert(self, host, port): def find_cert(self, host, port):
if self.config.certfile: if self.config.certfile:
@ -374,7 +379,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
return None return None
method, scheme, host, port, path, httpminor = parse_request_line(line) method, scheme, host, port, path, httpminor = parse_request_line(line)
if method == "CONNECT": if method == "CONNECT":
# Discard additional headers sent to the proxy. Should I expose # FIXME: Discard additional headers sent to the proxy. Should I expose
# these to users? # these to users?
while 1: while 1:
d = self.rfile.readline() d = self.rfile.readline()
@ -382,7 +387,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
break break
self.wfile.write( self.wfile.write(
'HTTP/1.1 200 Connection established\r\n' + 'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n'%NAME) + ('Proxy-agent: %s\r\n'%version.NAMEVERSION) +
'\r\n' '\r\n'
) )
self.wfile.flush() self.wfile.flush()
@ -425,7 +430,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
expect = ",".join(headers['expect']) expect = ",".join(headers['expect'])
if expect == "100-continue" and httpminor >= 1: if expect == "100-continue" and httpminor >= 1:
self.wfile.write('HTTP/1.1 100 Continue\r\n') self.wfile.write('HTTP/1.1 100 Continue\r\n')
self.wfile.write('Proxy-agent: %s\r\n'%NAME) self.wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
self.wfile.write('\r\n') self.wfile.write('\r\n')
del headers['expect'] del headers['expect']
else: else:
@ -463,7 +468,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
import BaseHTTPServer import BaseHTTPServer
response = BaseHTTPServer.BaseHTTPRequestHandler.responses[code][0] response = BaseHTTPServer.BaseHTTPRequestHandler.responses[code][0]
self.wfile.write("HTTP/1.1 %s %s\r\n" % (code, response)) self.wfile.write("HTTP/1.1 %s %s\r\n" % (code, response))
self.wfile.write("Server: %s\r\n"%NAME) self.wfile.write("Server: %s\r\n"%version.NAMEVERSION)
self.wfile.write("Connection: close\r\n") self.wfile.write("Connection: close\r\n")
self.wfile.write("Content-type: text/html\r\n") self.wfile.write("Content-type: text/html\r\n")
self.wfile.write("\r\n") self.wfile.write("\r\n")
@ -494,6 +499,7 @@ class ProxyServer(ServerBase):
self.masterq = None self.masterq = None
self.certdir = tempfile.mkdtemp(prefix="mitmproxy") self.certdir = tempfile.mkdtemp(prefix="mitmproxy")
config.certdir = self.certdir config.certdir = self.certdir
self.apps = wsgi.AppRegistry()
def start_slave(self, klass, masterq): def start_slave(self, klass, masterq):
slave = klass(masterq, self) slave = klass(masterq, self)

View File

@ -1,2 +1,4 @@
IVERSION = (0, 8) IVERSION = (0, 8)
VERSION = ".".join(str(i) for i in IVERSION) VERSION = ".".join(str(i) for i in IVERSION)
# Program name, and the "<name> <version>" string built from it.
NAME = "mitmproxy"
NAMEVERSION = "%s %s" % (NAME, VERSION)

120
libmproxy/wsgi.py Normal file
View File

@ -0,0 +1,120 @@
import cStringIO, urllib, time, sys
import version, flow
def date_time_string(timestamp=None):
    """
    Return a date and time formatted for a message header.

    timestamp: seconds since the epoch to format. When None (the default),
    the current time is used.

    The result is always expressed in GMT, e.g.
    "Thu, 01 Jan 1970 00:00:00 GMT". Day and month names are fixed English
    abbreviations rather than locale-dependent ones.
    """
    WEEKS = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
    # Index 0 is a placeholder so that month numbers (1-12) index directly.
    MONTHS = (None,
              'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
              'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec')
    if timestamp is None:
        timestamp = time.time()
    # Only the first seven fields of the time tuple are needed.
    year, month, day, hh, mm, ss, wd = time.gmtime(timestamp)[:7]
    return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        WEEKS[wd],
        day, MONTHS[month], year,
        hh, mm, ss
    )
class WSGIAdaptor:
def __init__(self, app, domain, port):
self.app, self.domain, self.port = app, domain, port
def make_environ(self, request, errsoc):
if '?' in request.path:
path_info, query = request.path.split('?', 1)
else:
path_info = request.path
query = ''
environ = {
'wsgi.version': (1, 0),
'wsgi.url_scheme': request.scheme,
'wsgi.input': cStringIO.StringIO(request.content),
'wsgi.errors': errsoc,
'wsgi.multithread': True,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': version.NAMEVERSION,
'REQUEST_METHOD': request.method,
'SCRIPT_NAME': '',
'PATH_INFO': urllib.unquote(path_info),
'QUERY_STRING': query,
'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0],
'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0],
'SERVER_NAME': self.domain,
'SERVER_PORT': self.port,
# FIXME: We need to pick up the protocol read from the request.
'SERVER_PROTOCOL': "HTTP/1.1",
}
if request.client_conn.address:
environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address
for key, value in request.headers.items():
key = 'HTTP_' + key.upper().replace('-', '_')
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
environ[key] = value
return environ
def serve(self, request, soc):
state = dict(
response_started = False,
headers_sent = False,
status = None,
headers = None
)
def write(data):
if not state["headers_sent"]:
soc.write("HTTP/1.1 %s\r\n"%state["status"])
h = state["headers"]
if 'server' not in h:
h["Server"] = [version.NAMEVERSION]
if 'date' not in h:
h["Date"] = [date_time_string()]
soc.write(str(h))
soc.write("\r\n")
state["headers_sent"] = True
soc.write(data)
soc.flush()
def start_response(status, headers, exc_info=None):
if exc_info:
try:
if state["headers_sent"]:
raise exc_info[0], exc_info[1], exc_info[2]
finally:
exc_info = None
elif state["status"]:
raise AssertionError('Response already started')
state["status"] = status
state["headers"] = flow.ODictCaseless(headers)
return write
errs = cStringIO.StringIO()
try:
dataiter = self.app(self.make_environ(request, errs), start_response)
for i in dataiter:
write(i)
if not state["headers_sent"]:
write("")
except Exception, v:
print v
try:
# Serve internal server error page
pass
except Exception, v:
pass
return errs.getvalue()
class AppRegistry:
    """
    Registry linking (domain, port) combinations to WSGI app adaptors.
    """
    def __init__(self):
        self.apps = dict()

    def add(self, app, domain, port):
        """Wrap app in a WSGIAdaptor and register it under (domain, port)."""
        adaptor = WSGIAdaptor(app, domain, port)
        key = (domain, port)
        self.apps[key] = adaptor

    def get(self, request):
        """
            Returns a WSGIAdaptor instance if the request's (host, port)
            matches a registered app, or None.
        """
        key = (request.host, request.port)
        return self.apps.get(key)

View File

@ -1030,6 +1030,15 @@ class uODictCaseless(libpry.AutoTree):
def setUp(self): def setUp(self):
self.od = flow.ODictCaseless() self.od = flow.ODictCaseless()
def test_case_preservation(self):
    """Stored keys keep their original case; membership and get() ignore case."""
    od = self.od
    od["Foo"] = ["1"]
    assert "foo" in od
    first_key = od.items()[0][0]
    assert first_key == "Foo"
    expected = ["1"]
    assert od.get("foo") == expected
    assert od.get("foo", [""]) == expected
    assert od.get("Foo", [""]) == expected
    # A missing key yields the supplied default unchanged.
    assert od.get("xx", "yy") == "yy"
def test_del(self): def test_del(self):
self.od.add("foo", 1) self.od.add("foo", 1)
self.od.add("Foo", 2) self.od.add("Foo", 2)

61
test/test_wsgi.py Normal file
View File

@ -0,0 +1,61 @@
import cStringIO
import libpry
from libmproxy import wsgi
import tutils
class TestApp:
    """
    Minimal WSGI application used as a fixture; records whether it was called.
    """
    def __init__(self):
        self.called = False

    def __call__(self, environ, start_response):
        """Mark the app as called, start a 200 text/plain response and return the body chunks."""
        self.called = True
        start_response('200 OK', [('Content-type', 'text/plain')])
        body = ['Hello', ' world!\n']
        return body
class uWSGIAdaptor(libpry.AutoTree):
    """Tests for wsgi.WSGIAdaptor."""
    def test_make_environ(self):
        """make_environ yields a truthy environ for a stock test request."""
        adaptor = wsgi.WSGIAdaptor(None, "foo", 80)
        environ = adaptor.make_environ(tutils.treq(), None)
        assert environ

    def test_serve(self):
        """serve calls the app, reports no errors, and writes body and Server header."""
        app = TestApp()
        adaptor = wsgi.WSGIAdaptor(app, "foo", 80)
        req = tutils.treq()
        req.host, req.port = "foo", 80

        out = cStringIO.StringIO()
        errors = adaptor.serve(req, out)
        assert app.called
        assert not errors

        written = out.getvalue()
        assert "Hello world" in written
        assert "Server:" in written
class uAppRegistry(libpry.AutoTree):
    """Tests for wsgi.AppRegistry."""
    def test_add_get(self):
        """A registered app is found by matching host and port, and not by a different port."""
        registry = wsgi.AppRegistry()
        registry.add("foo", "domain", 80)

        req = tutils.treq()
        req.host, req.port = "domain", 80
        assert registry.get(req)

        req.port = 81
        assert not registry.get(req)
# Test tree instances exposed by this module.
# NOTE(review): presumably picked up by the libpry runner - confirm.
tests = [uWSGIAdaptor(), uAppRegistry()]