diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 21a8f3a2f..a737057e3 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -21,7 +21,7 @@ import hashlib, Cookie, cookielib, copy, re, urlparse import time import tnetstring, filt, script, utils, encoding, proxy from email.utils import parsedate_tz, formatdate, mktime_tz -import controller, version, certutils +import controller, version, certutils, protocol HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" CONTENT_MISSING = 0 @@ -514,7 +514,7 @@ class Request(HTTPMsg): Returns False if the URL was invalid, True if the request succeeded. """ - parts = utils.parse_url(url) + parts = protocol.parse_url(url) if not parts: return False self.scheme, self.host, self.port, self.path = parts diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py index 3e393b46e..547bff9e1 100644 --- a/libmproxy/protocol.py +++ b/libmproxy/protocol.py @@ -1,5 +1,4 @@ -import string -import flow, utils +import string, urlparse class ProtocolError(Exception): def __init__(self, code, msg): @@ -9,6 +8,31 @@ class ProtocolError(Exception): return "ProtocolError(%s, %s)"%(self.code, self.msg) +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + """ + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + if not scheme: + return None + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + return scheme, host, port, path + + def read_headers(fp): """ Read a set of headers from a file pointer. Stop once a blank line @@ -30,7 +54,7 @@ def read_headers(fp): name = line[:i] value = line[i+1:].strip() ret.append([name, value]) - return flow.ODictCaseless(ret) + return ret def read_chunked(fp, limit): @@ -128,7 +152,7 @@ def parse_init_proxy(line): method, url, protocol = string.split(line) except ValueError: return None - parts = utils.parse_url(url) + parts = parse_url(url) if not parts: return None scheme, host, port, path = parts diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 122afacd0..58ab7a580 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -53,7 +53,7 @@ class RequestReplayThread(threading.Thread): server.send(r) response = server.read_response(r) response._send(self.masterq) - except (ProxyError, ProtocolError), v: + except (ProxyError, protocol.ProtocolError), v: err = flow.Error(self.flow.request, v.msg) err._send(self.masterq) except netlib.NetLibError, v: @@ -107,7 +107,7 @@ class ServerConnection(netlib.TCPClient): code = int(code) except ValueError: raise ProxyError(502, "Invalid server response: %s."%line) - headers = protocol.read_headers(self.rfile) + headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) if code >= 100 and code <= 199: return self.read_response() if request.method == "HEAD" or code == 204 or code == 304: @@ -203,7 +203,7 @@ class ProxyHandler(netlib.BaseHandler): return except IOError, v: cc.connection_error = v - except (ProxyError, ProtocolError), e: + except (ProxyError, protocol.ProtocolError), e: cc.connection_error = "%s: %s"%(e.code, e.msg) if request: err = flow.Error(request, e.msg) @@ -243,7 +243,7 @@ class ProxyHandler(netlib.BaseHandler): else: scheme = "http" method, path, httpversion = protocol.parse_init_http(line) - headers = protocol.read_headers(self.rfile) + headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) @@ -251,7 +251,7 @@ class ProxyHandler(netlib.BaseHandler): elif self.config.reverse_proxy: scheme, host, port = self.config.reverse_proxy method, path, httpversion = protocol.parse_init_http(line) - headers = protocol.read_headers(self.rfile) + headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) @@ -278,14 +278,14 @@ class ProxyHandler(netlib.BaseHandler): if self.proxy_connect_state: host, port, httpversion = self.proxy_connect_state method, path, httpversion = protocol.parse_init_http(line) - headers = protocol.read_headers(self.rfile) + headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content) else: method, scheme, host, port, path, httpversion = protocol.parse_init_proxy(line) - headers = protocol.read_headers(self.rfile) + headers = flow.ODictCaseless(protocol.read_headers(self.rfile)) content = protocol.read_http_body_request( self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) diff --git a/libmproxy/utils.py b/libmproxy/utils.py index 337d43785..989bb6951 100644 --- a/libmproxy/utils.py +++ b/libmproxy/utils.py @@ -15,6 +15,7 @@ import os, datetime, urlparse, string, urllib, re import time, functools, cgi import json +import protocol def timestamp(): """ @@ -194,33 +195,8 @@ class LRUCache: return wrap -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - """ - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - if not scheme: - return None - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - return scheme, host, port, path - - def parse_proxy_spec(url): - p = parse_url(url) + p = protocol.parse_url(url) if not p or not p[1]: return None return p[:3] diff --git a/test/test_protocol.py b/test/test_protocol.py index d0fa2e36d..81b5fefb6 100644 --- a/test/test_protocol.py +++ b/test/test_protocol.py @@ -107,9 +107,8 @@ class TestReadHeaders: data = textwrap.dedent(data) data = data.strip() s = cStringIO.StringIO(data) - headers = protocol.read_headers(s) - assert headers["header"] == ["one"] - assert headers["header2"] == ["two"] + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header2", "two"]] def test_read_multi(self): data = """ @@ -120,8 +119,8 @@ class TestReadHeaders: data = textwrap.dedent(data) data = data.strip() s = cStringIO.StringIO(data) - headers = protocol.read_headers(s) - assert headers["header"] == ["one", "two"] + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header", "two"]] def test_read_continued(self): data = """ @@ -133,7 +132,32 @@ class TestReadHeaders: data = textwrap.dedent(data) data = data.strip() s = cStringIO.StringIO(data) - headers = protocol.read_headers(s) - assert headers["header"] == ['one\r\n two'] + h = protocol.read_headers(s) + assert h == [["Header", "one\r\n two"], ["Header2", "three"]] +def test_parse_url(): + assert not protocol.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = protocol.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = protocol.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = protocol.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = protocol.parse_url("https://foo") + assert po == 443 + + assert not protocol.parse_url("https://foo:bar") + assert not protocol.parse_url("https://foo:") + diff --git a/test/test_utils.py b/test/test_utils.py index 8a881b4e8..e23d919fd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -107,32 +107,6 @@ def test_unparse_url(): assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" -def test_parse_url(): - assert not utils.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = utils.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = utils.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = utils.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = utils.parse_url("https://foo") - assert po == 443 - - assert not utils.parse_url("https://foo:bar") - assert not utils.parse_url("https://foo:") - - def test_parse_size(): assert not utils.parse_size("") assert utils.parse_size("1") == 1