Move wsgi to netlib.

This commit is contained in:
Aldo Cortesi 2012-06-19 10:42:55 +12:00
parent 1b1ccab8b7
commit 7cb242c168
4 changed files with 34 additions and 254 deletions

View File

@ -16,8 +16,8 @@ import sys, os, string, socket, time
import shutil, tempfile, threading
import optparse, SocketServer
from OpenSSL import SSL
from netlib import odict, tcp, protocol
import utils, flow, certutils, version, wsgi
from netlib import odict, tcp, protocol, wsgi
import utils, flow, certutils, version
class ProxyError(Exception):
@ -333,7 +333,7 @@ class ProxyServer(tcp.TCPServer):
self.masterq = None
self.certdir = tempfile.mkdtemp(prefix="mitmproxy")
config.certdir = self.certdir
self.apps = wsgi.AppRegistry()
self.apps = AppRegistry()
def start_slave(self, klass, masterq):
slave = klass(masterq, self)
@ -352,6 +352,24 @@ class ProxyServer(tcp.TCPServer):
pass
class AppRegistry:
def __init__(self):
self.apps = {}
def add(self, app, domain, port):
"""
Add a WSGI app to the registry, to be served for requests to the
specified domain, on the specified port.
"""
self.apps[(domain, port)] = wsgi.WSGIAdaptor(app, domain, port, version.NAMEVERSION)
def get(self, request):
"""
Returns an WSGIAdaptor instance if request matches an app, or None.
"""
return self.apps.get((request.host, request.port), None)
class DummyServer:
bound = False
def __init__(self, config):

View File

@ -1,141 +0,0 @@
import cStringIO, urllib, time, sys, traceback
import version, flow
def date_time_string():
"""Return the current date and time formatted for a message header."""
WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
MONTHS = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
now = time.time()
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now)
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
WEEKS[wd],
day, MONTHS[month], year,
hh, mm, ss)
return s
class WSGIAdaptor:
def __init__(self, app, domain, port):
self.app, self.domain, self.port = app, domain, port
def make_environ(self, request, errsoc):
if '?' in request.path:
path_info, query = request.path.split('?', 1)
else:
path_info = request.path
query = ''
environ = {
'wsgi.version': (1, 0),
'wsgi.url_scheme': request.scheme,
'wsgi.input': cStringIO.StringIO(request.content),
'wsgi.errors': errsoc,
'wsgi.multithread': True,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': version.NAMEVERSION,
'REQUEST_METHOD': request.method,
'SCRIPT_NAME': '',
'PATH_INFO': urllib.unquote(path_info),
'QUERY_STRING': query,
'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0],
'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0],
'SERVER_NAME': self.domain,
'SERVER_PORT': self.port,
# FIXME: We need to pick up the protocol read from the request.
'SERVER_PROTOCOL': "HTTP/1.1",
}
if request.client_conn.address:
environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address
for key, value in request.headers.items():
key = 'HTTP_' + key.upper().replace('-', '_')
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
environ[key] = value
return environ
def error_page(self, soc, headers_sent, s):
"""
Make a best-effort attempt to write an error page. If headers are
already sent, we just bung the error into the page.
"""
c = """
<html>
<h1>Internal Server Error</h1>
<pre>%s"</pre>
</html>
"""%s
if not headers_sent:
soc.write("HTTP/1.1 500 Internal Server Error\r\n")
soc.write("Content-Type: text/html\r\n")
soc.write("Content-Length: %s\r\n"%len(c))
soc.write("\r\n")
soc.write(c)
def serve(self, request, soc):
state = dict(
response_started = False,
headers_sent = False,
status = None,
headers = None
)
def write(data):
if not state["headers_sent"]:
soc.write("HTTP/1.1 %s\r\n"%state["status"])
h = state["headers"]
if 'server' not in h:
h["Server"] = [version.NAMEVERSION]
if 'date' not in h:
h["Date"] = [date_time_string()]
soc.write(str(h))
soc.write("\r\n")
state["headers_sent"] = True
soc.write(data)
soc.flush()
def start_response(status, headers, exc_info=None):
if exc_info:
try:
if state["headers_sent"]:
raise exc_info[0], exc_info[1], exc_info[2]
finally:
exc_info = None
elif state["status"]:
raise AssertionError('Response already started')
state["status"] = status
state["headers"] = flow.ODictCaseless(headers)
return write
errs = cStringIO.StringIO()
try:
dataiter = self.app(self.make_environ(request, errs), start_response)
for i in dataiter:
write(i)
if not state["headers_sent"]:
write("")
except Exception, v:
try:
s = traceback.format_exc()
self.error_page(soc, state["headers_sent"], s)
except Exception, v: # pragma: no cover
pass # pragma: no cover
return errs.getvalue()
class AppRegistry:
def __init__(self):
self.apps = {}
def add(self, app, domain, port):
"""
Add a WSGI app to the registry, to be served for requests to the
specified domain, on the specified port.
"""
self.apps[(domain, port)] = WSGIAdaptor(app, domain, port)
def get(self, request):
"""
Returns an WSGIAdaptor instance if request matches an app, or None.
"""
return self.apps.get((request.host, request.port), None)

View File

@ -10,3 +10,16 @@ class TestProxyError:
p = proxy.ProxyError(111, "msg")
assert repr(p)
class TestAppRegistry:
def test_add_get(self):
ar = proxy.AppRegistry()
ar.add("foo", "domain", 80)
r = tutils.treq()
r.host = "domain"
r.port = 80
assert ar.get(r)
r.port = 81
assert not ar.get(r)

View File

@ -1,110 +0,0 @@
import cStringIO, sys
from libmproxy import wsgi
import tutils
class TestApp:
def __init__(self):
self.called = False
def __call__(self, environ, start_response):
self.called = True
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
return ['Hello', ' world!\n']
class TestWSGIAdaptor:
def test_make_environ(self):
w = wsgi.WSGIAdaptor(None, "foo", 80)
tr = tutils.treq()
assert w.make_environ(tr, None)
tr.path = "/foo?bar=voing"
r = w.make_environ(tr, None)
assert r["QUERY_STRING"] == "bar=voing"
def test_serve(self):
ta = TestApp()
w = wsgi.WSGIAdaptor(ta, "foo", 80)
r = tutils.treq()
r.host = "foo"
r.port = 80
wfile = cStringIO.StringIO()
err = w.serve(r, wfile)
assert ta.called
assert not err
val = wfile.getvalue()
assert "Hello world" in val
assert "Server:" in val
def _serve(self, app):
w = wsgi.WSGIAdaptor(app, "foo", 80)
r = tutils.treq()
r.host = "foo"
r.port = 80
wfile = cStringIO.StringIO()
err = w.serve(r, wfile)
return wfile.getvalue()
def test_serve_empty_body(self):
def app(environ, start_response):
status = '200 OK'
response_headers = [('Foo', 'bar')]
start_response(status, response_headers)
return []
assert self._serve(app)
def test_serve_double_start(self):
def app(environ, start_response):
try:
raise ValueError("foo")
except:
ei = sys.exc_info()
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
start_response(status, response_headers)
assert "Internal Server Error" in self._serve(app)
def test_serve_single_err(self):
def app(environ, start_response):
try:
raise ValueError("foo")
except:
ei = sys.exc_info()
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers, ei)
assert "Internal Server Error" in self._serve(app)
def test_serve_double_err(self):
def app(environ, start_response):
try:
raise ValueError("foo")
except:
ei = sys.exc_info()
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
yield "aaa"
start_response(status, response_headers, ei)
yield "bbb"
assert "Internal Server Error" in self._serve(app)
class TestAppRegistry:
def test_add_get(self):
ar = wsgi.AppRegistry()
ar.add("foo", "domain", 80)
r = tutils.treq()
r.host = "domain"
r.port = 80
assert ar.get(r)
r.port = 81
assert not ar.get(r)