diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 04734fcbb..0805a63d3 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -16,8 +16,8 @@ import sys, os, string, socket, time import shutil, tempfile, threading import optparse, SocketServer from OpenSSL import SSL -from netlib import odict, tcp, protocol -import utils, flow, certutils, version, wsgi +from netlib import odict, tcp, protocol, wsgi +import utils, flow, certutils, version class ProxyError(Exception): @@ -333,7 +333,7 @@ class ProxyServer(tcp.TCPServer): self.masterq = None self.certdir = tempfile.mkdtemp(prefix="mitmproxy") config.certdir = self.certdir - self.apps = wsgi.AppRegistry() + self.apps = AppRegistry() def start_slave(self, klass, masterq): slave = klass(masterq, self) @@ -352,6 +352,24 @@ class ProxyServer(tcp.TCPServer): pass +class AppRegistry: + def __init__(self): + self.apps = {} + + def add(self, app, domain, port): + """ + Add a WSGI app to the registry, to be served for requests to the + specified domain, on the specified port. + """ + self.apps[(domain, port)] = wsgi.WSGIAdaptor(app, domain, port, version.NAMEVERSION) + + def get(self, request): + """ + Returns an WSGIAdaptor instance if request matches an app, or None. + """ + return self.apps.get((request.host, request.port), None) + + class DummyServer: bound = False def __init__(self, config): diff --git a/libmproxy/wsgi.py b/libmproxy/wsgi.py deleted file mode 100644 index b4555312d..000000000 --- a/libmproxy/wsgi.py +++ /dev/null @@ -1,141 +0,0 @@ -import cStringIO, urllib, time, sys, traceback -import version, flow - -def date_time_string(): - """Return the current date and time formatted for a message header.""" - WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] - now = time.time() - year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) - s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss) - return s - - -class WSGIAdaptor: - def __init__(self, app, domain, port): - self.app, self.domain, self.port = app, domain, port - - def make_environ(self, request, errsoc): - if '?' in request.path: - path_info, query = request.path.split('?', 1) - else: - path_info = request.path - query = '' - environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': request.scheme, - 'wsgi.input': cStringIO.StringIO(request.content), - 'wsgi.errors': errsoc, - 'wsgi.multithread': True, - 'wsgi.multiprocess': False, - 'wsgi.run_once': False, - 'SERVER_SOFTWARE': version.NAMEVERSION, - 'REQUEST_METHOD': request.method, - 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.unquote(path_info), - 'QUERY_STRING': query, - 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], - 'SERVER_NAME': self.domain, - 'SERVER_PORT': self.port, - # FIXME: We need to pick up the protocol read from the request. - 'SERVER_PROTOCOL': "HTTP/1.1", - } - if request.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address - - for key, value in request.headers.items(): - key = 'HTTP_' + key.upper().replace('-', '_') - if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): - environ[key] = value - return environ - - def error_page(self, soc, headers_sent, s): - """ - Make a best-effort attempt to write an error page. If headers are - already sent, we just bung the error into the page. - """ - c = """ - -
%s"- - """%s - if not headers_sent: - soc.write("HTTP/1.1 500 Internal Server Error\r\n") - soc.write("Content-Type: text/html\r\n") - soc.write("Content-Length: %s\r\n"%len(c)) - soc.write("\r\n") - soc.write(c) - - def serve(self, request, soc): - state = dict( - response_started = False, - headers_sent = False, - status = None, - headers = None - ) - def write(data): - if not state["headers_sent"]: - soc.write("HTTP/1.1 %s\r\n"%state["status"]) - h = state["headers"] - if 'server' not in h: - h["Server"] = [version.NAMEVERSION] - if 'date' not in h: - h["Date"] = [date_time_string()] - soc.write(str(h)) - soc.write("\r\n") - state["headers_sent"] = True - soc.write(data) - soc.flush() - - def start_response(status, headers, exc_info=None): - if exc_info: - try: - if state["headers_sent"]: - raise exc_info[0], exc_info[1], exc_info[2] - finally: - exc_info = None - elif state["status"]: - raise AssertionError('Response already started') - state["status"] = status - state["headers"] = flow.ODictCaseless(headers) - return write - - errs = cStringIO.StringIO() - try: - dataiter = self.app(self.make_environ(request, errs), start_response) - for i in dataiter: - write(i) - if not state["headers_sent"]: - write("") - except Exception, v: - try: - s = traceback.format_exc() - self.error_page(soc, state["headers_sent"], s) - except Exception, v: # pragma: no cover - pass # pragma: no cover - return errs.getvalue() - - -class AppRegistry: - def __init__(self): - self.apps = {} - - def add(self, app, domain, port): - """ - Add a WSGI app to the registry, to be served for requests to the - specified domain, on the specified port. - """ - self.apps[(domain, port)] = WSGIAdaptor(app, domain, port) - - def get(self, request): - """ - Returns an WSGIAdaptor instance if request matches an app, or None. - """ - return self.apps.get((request.host, request.port), None) diff --git a/test/test_proxy.py b/test/test_proxy.py index 1e1369df4..08b3634fa 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -10,3 +10,16 @@ class TestProxyError: p = proxy.ProxyError(111, "msg") assert repr(p) + +class TestAppRegistry: + def test_add_get(self): + ar = proxy.AppRegistry() + ar.add("foo", "domain", 80) + + r = tutils.treq() + r.host = "domain" + r.port = 80 + assert ar.get(r) + + r.port = 81 + assert not ar.get(r) diff --git a/test/test_wsgi.py b/test/test_wsgi.py deleted file mode 100644 index 1ae81d119..000000000 --- a/test/test_wsgi.py +++ /dev/null @@ -1,110 +0,0 @@ -import cStringIO, sys -from libmproxy import wsgi -import tutils - - -class TestApp: - def __init__(self): - self.called = False - - def __call__(self, environ, start_response): - self.called = True - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - return ['Hello', ' world!\n'] - - -class TestWSGIAdaptor: - def test_make_environ(self): - w = wsgi.WSGIAdaptor(None, "foo", 80) - tr = tutils.treq() - assert w.make_environ(tr, None) - - tr.path = "/foo?bar=voing" - r = w.make_environ(tr, None) - assert r["QUERY_STRING"] == "bar=voing" - - def test_serve(self): - ta = TestApp() - w = wsgi.WSGIAdaptor(ta, "foo", 80) - r = tutils.treq() - r.host = "foo" - r.port = 80 - - wfile = cStringIO.StringIO() - err = w.serve(r, wfile) - assert ta.called - assert not err - - val = wfile.getvalue() - assert "Hello world" in val - assert "Server:" in val - - def _serve(self, app): - w = wsgi.WSGIAdaptor(app, "foo", 80) - r = tutils.treq() - r.host = "foo" - r.port = 80 - wfile = cStringIO.StringIO() - err = w.serve(r, wfile) - return wfile.getvalue() - - def test_serve_empty_body(self): - def app(environ, start_response): - status = '200 OK' - response_headers = [('Foo', 'bar')] - start_response(status, response_headers) - return [] - assert self._serve(app) - - def test_serve_double_start(self): - def app(environ, start_response): - try: - raise ValueError("foo") - except: - ei = sys.exc_info() - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - start_response(status, response_headers) - assert "Internal Server Error" in self._serve(app) - - def test_serve_single_err(self): - def app(environ, start_response): - try: - raise ValueError("foo") - except: - ei = sys.exc_info() - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers, ei) - assert "Internal Server Error" in self._serve(app) - - def test_serve_double_err(self): - def app(environ, start_response): - try: - raise ValueError("foo") - except: - ei = sys.exc_info() - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] - start_response(status, response_headers) - yield "aaa" - start_response(status, response_headers, ei) - yield "bbb" - assert "Internal Server Error" in self._serve(app) - - -class TestAppRegistry: - def test_add_get(self): - ar = wsgi.AppRegistry() - ar.add("foo", "domain", 80) - - r = tutils.treq() - r.host = "domain" - r.port = 80 - assert ar.get(r) - - r.port = 81 - assert not ar.get(r)