diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py
new file mode 100644
index 000000000..3e393b46e
--- /dev/null
+++ b/libmproxy/protocol.py
@@ -0,0 +1,196 @@
+import string
+import flow, utils, version
+class ProtocolError(Exception):
+    def __init__(self, code, msg):
+        self.code, self.msg = code, msg
+
+    def __str__(self):
+        return "ProtocolError(%s, %s)"%(self.code, self.msg)
+
+
+def read_headers(fp):
+    """
+        Read a set of headers from a file pointer. Stop once a blank line
+        is reached. Return an ODictCaseless object.
+    """
+    ret = []
+    name = ''
+    while 1:
+        line = fp.readline()
+        if not line or line == '\r\n' or line == '\n':
+            break
+        if line[0] in ' \t':
+            # continued header
+            ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
+        else:
+            i = line.find(':')
+            # We're being liberal in what we accept, here.
+            if i > 0:
+                name = line[:i]
+                value = line[i+1:].strip()
+                ret.append([name, value])
+    return flow.ODictCaseless(ret)
+
+
+def read_chunked(fp, limit):
+    content = ""
+    total = 0
+    while 1:
+        line = fp.readline(128)
+        if line == "":
+            raise IOError("Connection closed")
+        if line == '\r\n' or line == '\n':
+            continue
+        try:
+            length = int(line,16)
+        except ValueError:
+            # FIXME: Not strictly correct - this could be from the server, in which
+            # case we should send a 502.
+            raise ProtocolError(400, "Invalid chunked encoding length: %s"%line)
+        if not length:
+            break
+        total += length
+        if limit is not None and total > limit:
+            msg = "HTTP Body too large."\
+                  " Limit is %s, chunked content length was at least %s"%(limit, total)
+            raise ProtocolError(509, msg)
+        content += fp.read(length)
+        line = fp.readline(5)
+        if line != '\r\n':
+            raise IOError("Malformed chunked body")
+    while 1:
+        line = fp.readline()
+        if line == "":
+            raise IOError("Connection closed")
+        if line == '\r\n' or line == '\n':
+            break
+    return content
+
+
+def has_chunked_encoding(headers):
+    for i in headers["transfer-encoding"]:
+        for j in i.split(","):
+            if j.lower() == "chunked":
+                return True
+    return False
+
+
+def read_http_body(rfile, headers, all, limit):
+    if has_chunked_encoding(headers):
+        content = read_chunked(rfile, limit)
+    elif "content-length" in headers:
+        try:
+            l = int(headers["content-length"][0])
+        except ValueError:
+            # FIXME: Not strictly correct - this could be from the server, in which
+            # case we should send a 502.
+            raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"])
+        if limit is not None and l > limit:
+            raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))
+        content = rfile.read(l)
+    elif all:
+        content = rfile.read(limit if limit else None)
+    else:
+        content = ""
+    return content
+
+
+def parse_http_protocol(s):
+    if not s.startswith("HTTP/"):
+        return None
+    major, minor = s.split('/')[1].split('.')
+    major = int(major)
+    minor = int(minor)
+    return major, minor
+
+
+def parse_init_connect(line):
+    try:
+        method, url, protocol = string.split(line)
+    except ValueError:
+        return None
+    if method != 'CONNECT':
+        return None
+    try:
+        host, port = url.split(":")
+    except ValueError:
+        return None
+    port = int(port)
+    httpversion = parse_http_protocol(protocol)
+    if not httpversion:
+        return None
+    return host, port, httpversion
+
+
+def parse_init_proxy(line):
+    try:
+        method, url, protocol = string.split(line)
+    except ValueError:
+        return None
+    parts = utils.parse_url(url)
+    if not parts:
+        return None
+    scheme, host, port, path = parts
+    httpversion = parse_http_protocol(protocol)
+    if not httpversion:
+        return None
+    return method, scheme, host, port, path, httpversion
+
+
+def parse_init_http(line):
+    """
+        Returns (method, url, httpversion)
+    """
+    try:
+        method, url, protocol = string.split(line)
+    except ValueError:
+        return None
+    if not (url.startswith("/") or url == "*"):
+        return None
+    httpversion = parse_http_protocol(protocol)
+    if not httpversion:
+        return None
+    return method, url, httpversion
+
+
+def request_connection_close(httpversion, headers):
+    """
+        Checks the request to see if the client connection should be closed.
+    """
+    if "connection" in headers:
+        for value in ",".join(headers['connection']).split(","):
+            value = value.strip()
+            if value == "close":
+                return True
+            elif value == "keep-alive":
+                return False
+    # HTTP 1.1 connections are assumed to be persistent
+    if httpversion == (1, 1):
+        return False
+    return True
+
+
+def response_connection_close(httpversion, headers):
+    """
+        Checks the response to see if the client connection should be closed.
+    """
+    if request_connection_close(httpversion, headers):
+        return True
+    elif not has_chunked_encoding(headers) and "content-length" in headers:
+        return True
+    return False
+
+
+def read_http_body_request(rfile, wfile, headers, httpversion, limit):
+    if "expect" in headers:
+        # FIXME: Should be forwarded upstream
+        expect = ",".join(headers['expect'])
+        if expect == "100-continue" and httpversion >= (1, 1):
+            wfile.write('HTTP/1.1 100 Continue\r\n')
+            wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
+            wfile.write('\r\n')
+            del headers['expect']
+    return read_http_body(rfile, headers, False, limit)
+
+
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 9ebe01539..122afacd0 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -15,7 +15,7 @@ import sys, os, string, socket, time
 import shutil, tempfile, threading
 import optparse, SocketServer
-import utils, flow, certutils, version, wsgi, netlib
+import utils, flow, certutils, version, wsgi, netlib, protocol
 from OpenSSL import SSL
@@ -41,191 +41,6 @@ class ProxyConfig:
         self.transparent_proxy = transparent_proxy
 
 
-def read_headers(fp):
-    """
-        Read a set of headers from a file pointer. Stop once a blank line
-        is reached. Return a ODictCaseless object.
-    """
-    ret = []
-    name = ''
-    while 1:
-        line = fp.readline()
-        if not line or line == '\r\n' or line == '\n':
-            break
-        if line[0] in ' \t':
-            # continued header
-            ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
-        else:
-            i = line.find(':')
-            # We're being liberal in what we accept, here.
-            if i > 0:
-                name = line[:i]
-                value = line[i+1:].strip()
-                ret.append([name, value])
-    return flow.ODictCaseless(ret)
-
-
-def read_chunked(fp, limit):
-    content = ""
-    total = 0
-    while 1:
-        line = fp.readline(128)
-        if line == "":
-            raise IOError("Connection closed")
-        if line == '\r\n' or line == '\n':
-            continue
-        try:
-            length = int(line,16)
-        except ValueError:
-            # FIXME: Not strictly correct - this could be from the server, in which
-            # case we should send a 502.
-            raise ProxyError(400, "Invalid chunked encoding length: %s"%line)
-        if not length:
-            break
-        total += length
-        if limit is not None and total > limit:
-            msg = "HTTP Body too large."\
-                  " Limit is %s, chunked content length was at least %s"%(limit, total)
-            raise ProxyError(509, msg)
-        content += fp.read(length)
-        line = fp.readline(5)
-        if line != '\r\n':
-            raise IOError("Malformed chunked body")
-    while 1:
-        line = fp.readline()
-        if line == "":
-            raise IOError("Connection closed")
-        if line == '\r\n' or line == '\n':
-            break
-    return content
-
-
-def has_chunked_encoding(headers):
-    for i in headers["transfer-encoding"]:
-        for j in i.split(","):
-            if j.lower() == "chunked":
-                return True
-    return False
-
-
-def read_http_body(rfile, headers, all, limit):
-    if has_chunked_encoding(headers):
-        content = read_chunked(rfile, limit)
-    elif "content-length" in headers:
-        try:
-            l = int(headers["content-length"][0])
-        except ValueError:
-            # FIXME: Not strictly correct - this could be from the server, in which
-            # case we should send a 502.
-            raise ProxyError(400, "Invalid content-length header: %s"%headers["content-length"])
-        if limit is not None and l > limit:
-            raise ProxyError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))
-        content = rfile.read(l)
-    elif all:
-        content = rfile.read(limit if limit else None)
-    else:
-        content = ""
-    return content
-
-
-def parse_http_protocol(s):
-    if not s.startswith("HTTP/"):
-        return None
-    major, minor = s.split('/')[1].split('.')
-    major = int(major)
-    minor = int(minor)
-    return major, minor
-
-
-def parse_init_connect(line):
-    try:
-        method, url, protocol = string.split(line)
-    except ValueError:
-        return None
-    if method != 'CONNECT':
-        return None
-    try:
-        host, port = url.split(":")
-    except ValueError:
-        return None
-    port = int(port)
-    httpversion = parse_http_protocol(protocol)
-    if not httpversion:
-        return None
-    return host, port, httpversion
-
-
-def parse_init_proxy(line):
-    try:
-        method, url, protocol = string.split(line)
-    except ValueError:
-        return None
-    parts = utils.parse_url(url)
-    if not parts:
-        return None
-    scheme, host, port, path = parts
-    httpversion = parse_http_protocol(protocol)
-    if not httpversion:
-        return None
-    return method, scheme, host, port, path, httpversion
-
-
-def parse_init_http(line):
-    """
-        Returns (method, url, httpversion)
-    """
-    try:
-        method, url, protocol = string.split(line)
-    except ValueError:
-        return None
-    if not (url.startswith("/") or url == "*"):
-        return None
-    httpversion = parse_http_protocol(protocol)
-    if not httpversion:
-        return None
-    return method, url, httpversion
-
-
-def request_connection_close(httpversion, headers):
-    """
-        Checks the request to see if the client connection should be closed.
- """ - if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - return True - elif value == "keep-alive": - return False - # HTTP 1.1 connections are assumed to be persistent - if httpversion == (1, 1): - return False - return True - - -def response_connection_close(httpversion, headers): - """ - Checks the response to see if the client connection should be closed. - """ - if request_connection_close(httpversion, headers): - return True - elif not has_chunked_encoding(headers) and "content-length" in headers: - return True - return False - - -def read_http_body_request(rfile, wfile, headers, httpversion, limit): - if "expect" in headers: - # FIXME: Should be forwarded upstream - expect = ",".join(headers['expect']) - if expect == "100-continue" and httpversion >= (1, 1): - wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) - wfile.write('\r\n') - del headers['expect'] - return read_http_body(rfile, headers, False, limit) - - class RequestReplayThread(threading.Thread): def __init__(self, config, flow, masterq): self.config, self.flow, self.masterq = config, flow, masterq @@ -238,7 +53,7 @@ class RequestReplayThread(threading.Thread): server.send(r) response = server.read_response(r) response._send(self.masterq) - except ProxyError, v: + except (ProxyError, ProtocolError), v: err = flow.Error(self.flow.request, v.msg) err._send(self.masterq) except netlib.NetLibError, v: @@ -285,20 +100,20 @@ class ServerConnection(netlib.TCPClient): if not len(parts) == 3: raise ProxyError(502, "Invalid server response: %s."%line) proto, code, msg = parts - httpversion = parse_http_protocol(proto) + httpversion = protocol.parse_http_protocol(proto) if httpversion is None: raise ProxyError(502, "Invalid HTTP version: %s."%httpversion) try: code = int(code) except ValueError: raise ProxyError(502, "Invalid server response: %s."%line) - headers = read_headers(self.rfile) + headers = protocol.read_headers(self.rfile) if code >= 100 and code <= 199: return self.read_response() if request.method == "HEAD" or code == 204 or code == 304: content = "" else: - content = read_http_body(self.rfile, headers, True, self.config.body_size_limit) + content = protocol.read_http_body(self.rfile, headers, True, self.config.body_size_limit) return flow.Response(request, httpversion, code, msg, headers, content, self.cert) def terminate(self): @@ -378,17 +193,17 @@ class ProxyHandler(netlib.BaseHandler): if response is None: return self.send_response(response) - if request_connection_close(request.httpversion, request.headers): + if protocol.request_connection_close(request.httpversion, request.headers): return # We could keep the client connection when the server # connection needs to go away. However, we want to mimic # behaviour as closely as possible to the client, so we # disconnect. 
-            if response_connection_close(response.httpversion, response.headers):
+            if protocol.response_connection_close(response.httpversion, response.headers):
                 return
         except IOError, v:
             cc.connection_error = v
-        except ProxyError, e:
+        except (ProxyError, ProtocolError), e:
            cc.connection_error = "%s: %s"%(e.code, e.msg)
             if request:
                 err = flow.Error(request, e.msg)
@@ -427,23 +242,23 @@ class ProxyHandler(netlib.BaseHandler):
                 self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
             else:
                 scheme = "http"
-            method, path, httpversion = parse_init_http(line)
-            headers = read_headers(self.rfile)
-            content = read_http_body_request(
+            method, path, httpversion = protocol.parse_init_http(line)
+            headers = protocol.read_headers(self.rfile)
+            content = protocol.read_http_body_request(
                 self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
             )
             return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content)
         elif self.config.reverse_proxy:
             scheme, host, port = self.config.reverse_proxy
-            method, path, httpversion = parse_init_http(line)
-            headers = read_headers(self.rfile)
-            content = read_http_body_request(
+            method, path, httpversion = protocol.parse_init_http(line)
+            headers = protocol.read_headers(self.rfile)
+            content = protocol.read_http_body_request(
                 self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
             )
             return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content)
         else:
             if line.startswith("CONNECT"):
-                host, port, httpversion = parse_init_connect(line)
+                host, port, httpversion = protocol.parse_init_connect(line)
                 # FIXME: Discard additional headers sent to the proxy. Should I expose
                 # these to users?
                 while 1:
@@ -462,16 +277,16 @@ class ProxyHandler(netlib.BaseHandler):
                 line = self.rfile.readline(line)
             if self.proxy_connect_state:
                 host, port, httpversion = self.proxy_connect_state
-                method, path, httpversion = parse_init_http(line)
-                headers = read_headers(self.rfile)
-                content = read_http_body_request(
+                method, path, httpversion = protocol.parse_init_http(line)
+                headers = protocol.read_headers(self.rfile)
+                content = protocol.read_http_body_request(
                     self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
                 )
                 return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
             else:
-                method, scheme, host, port, path, httpversion = parse_init_proxy(line)
-                headers = read_headers(self.rfile)
-                content = read_http_body_request(
+                method, scheme, host, port, path, httpversion = protocol.parse_init_proxy(line)
+                headers = protocol.read_headers(self.rfile)
+                content = protocol.read_http_body_request(
                     self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
                 )
                 return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content)
diff --git a/test/test_netlib.py b/test/test_netlib.py
index 12aa2acc1..19902d177 100644
--- a/test/test_netlib.py
+++ b/test/test_netlib.py
@@ -72,7 +72,6 @@ class TestServer(ServerTestBase):
         assert "Testing an error" in self.q.get()
 
 
-
 class TestTCPClient:
     def test_conerr(self):
         tutils.raises(netlib.NetLibError, netlib.TCPClient, False, "127.0.0.1", 0, None)
@@ -92,4 +91,3 @@ class TestFileLike:
         s = cStringIO.StringIO("foobar\nfoobar")
         s = netlib.FileLike(s)
         assert s.readline(3) == "foo"
-
diff --git a/test/test_protocol.py b/test/test_protocol.py
new file mode 100644
index 000000000..d0fa2e36d
--- /dev/null
+++ b/test/test_protocol.py
@@ -0,0 +1,139 @@
+import cStringIO, textwrap
+from libmproxy import protocol, flow
+import tutils
+def test_has_chunked_encoding():
+    h = flow.ODictCaseless()
+    assert not protocol.has_chunked_encoding(h)
+    h["transfer-encoding"] = ["chunked"]
+    assert protocol.has_chunked_encoding(h)
+
+
+def test_read_chunked():
+    s = cStringIO.StringIO("1\r\na\r\n0\r\n")
+    tutils.raises(IOError, protocol.read_chunked, s, None)
+
+    s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
+    assert protocol.read_chunked(s, None) == "a"
+
+    s = cStringIO.StringIO("\r\n")
+    tutils.raises(IOError, protocol.read_chunked, s, None)
+
+    s = cStringIO.StringIO("1\r\nfoo")
+    tutils.raises(IOError, protocol.read_chunked, s, None)
+
+    s = cStringIO.StringIO("foo\r\nfoo")
+    tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None)
+
+
+def test_request_connection_close():
+    h = flow.ODictCaseless()
+    assert protocol.request_connection_close((1, 0), h)
+    assert not protocol.request_connection_close((1, 1), h)
+
+    h["connection"] = ["keep-alive"]
+    assert not protocol.request_connection_close((1, 1), h)
+
+
+def test_read_http_body():
+    h = flow.ODict()
+    s = cStringIO.StringIO("testing")
+    assert protocol.read_http_body(s, h, False, None) == ""
+
+    h["content-length"] = ["foo"]
+    s = cStringIO.StringIO("testing")
+    tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None)
+
+    h["content-length"] = [5]
+    s = cStringIO.StringIO("testing")
+    assert len(protocol.read_http_body(s, h, False, None)) == 5
+    s = cStringIO.StringIO("testing")
+    tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4)
+
+    h = flow.ODict()
+    s = cStringIO.StringIO("testing")
+    assert len(protocol.read_http_body(s, h, True, 4)) == 4
+    s = cStringIO.StringIO("testing")
+    assert len(protocol.read_http_body(s, h, True, 100)) == 7
+def test_parse_http_protocol():
+    assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
+    assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
+    assert not protocol.parse_http_protocol("foo/0.0")
+
+
+def test_parse_init_connect():
+    assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
+    assert not protocol.parse_init_connect("bogus")
+    assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
+    assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
+    assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
+
+
+def test_parse_init_proxy():
+    u = "GET http://foo.com:8888/test HTTP/1.1"
+    m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
+    assert m == "GET"
+    assert s == "http"
+    assert h == "foo.com"
+    assert po == 8888
+    assert pa == "/test"
+    assert httpversion == (1, 1)
+
+    assert not protocol.parse_init_proxy("invalid")
+    assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
+    assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
+
+
+def test_parse_init_http():
+    u = "GET /test HTTP/1.1"
+    m, u, httpversion = protocol.parse_init_http(u)
+    assert m == "GET"
+    assert u == "/test"
+    assert httpversion == (1, 1)
+
+    assert not protocol.parse_init_http("invalid")
+    assert not protocol.parse_init_http("GET invalid HTTP/1.1")
+    assert not protocol.parse_init_http("GET /test foo/1.1")
+
+
+class TestReadHeaders:
+    def test_read_simple(self):
+        data = """
+            Header: one
+            Header2: two
+            \r\n
+        """
+        data = textwrap.dedent(data)
+        data = data.strip()
+        s = cStringIO.StringIO(data)
+        headers = protocol.read_headers(s)
+        assert headers["header"] == ["one"]
+        assert headers["header2"] == ["two"]
+
+    def test_read_multi(self):
+        data = """
+            Header: one
+            Header: two
+            \r\n
+        """
+        data = textwrap.dedent(data)
+        data = data.strip()
+        s = cStringIO.StringIO(data)
+        headers = protocol.read_headers(s)
+        assert headers["header"] == ["one", "two"]
+
+    def test_read_continued(self):
+        data = """
+            Header: one
+            \ttwo
+            Header2: three
+            \r\n
+        """
+        data = textwrap.dedent(data)
+        data = data.strip()
+        s = cStringIO.StringIO(data)
+        headers = protocol.read_headers(s)
+        assert headers["header"] == ['one\r\n two']
+
+
diff --git a/test/test_proxy.py b/test/test_proxy.py
index 5fab282c6..1e1369df4 100644
--- a/test/test_proxy.py
+++ b/test/test_proxy.py
@@ -5,143 +5,8 @@ from libmproxy import proxy, flow
 import tutils
 
 
-def test_has_chunked_encoding():
-    h = flow.ODictCaseless()
-    assert not proxy.has_chunked_encoding(h)
-    h["transfer-encoding"] = ["chunked"]
-    assert proxy.has_chunked_encoding(h)
-
-
-def test_read_chunked():
-    s = cStringIO.StringIO("1\r\na\r\n0\r\n")
-    tutils.raises(IOError, proxy.read_chunked, s, None)
-
-    s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
-    assert proxy.read_chunked(s, None) == "a"
-
-    s = cStringIO.StringIO("\r\n")
-    tutils.raises(IOError, proxy.read_chunked, s, None)
-
-    s = cStringIO.StringIO("1\r\nfoo")
-    tutils.raises(IOError, proxy.read_chunked, s, None)
-
-    s = cStringIO.StringIO("foo\r\nfoo")
-    tutils.raises(proxy.ProxyError, proxy.read_chunked, s, None)
-
-
-def test_request_connection_close():
-    h = flow.ODictCaseless()
-    assert proxy.request_connection_close((1, 0), h)
-    assert not proxy.request_connection_close((1, 1), h)
-
-    h["connection"] = ["keep-alive"]
-    assert not proxy.request_connection_close((1, 1), h)
-
-
-def test_read_http_body():
-    h = flow.ODict()
-    s = cStringIO.StringIO("testing")
-    assert proxy.read_http_body(s, h, False, None) == ""
-
-    h["content-length"] = ["foo"]
-    s = cStringIO.StringIO("testing")
-    tutils.raises(proxy.ProxyError, proxy.read_http_body, s, h, False, None)
-
-    h["content-length"] = [5]
-    s = cStringIO.StringIO("testing")
-    assert len(proxy.read_http_body(s, h, False, None)) == 5
-    s = cStringIO.StringIO("testing")
-    tutils.raises(proxy.ProxyError, proxy.read_http_body, s, h, False, 4)
-
-    h = flow.ODict()
-    s = cStringIO.StringIO("testing")
-    assert len(proxy.read_http_body(s, h, True, 4)) == 4
-    s = cStringIO.StringIO("testing")
-    assert len(proxy.read_http_body(s, h, True, 100)) == 7
-
-
 class TestProxyError:
     def test_simple(self):
         p = proxy.ProxyError(111, "msg")
         assert repr(p)
-
-class TestReadHeaders:
-    def test_read_simple(self):
-        data = """
-            Header: one
-            Header2: two
-            \r\n
-        """
-        data = textwrap.dedent(data)
-        data = data.strip()
-        s = StringIO(data)
-        headers = proxy.read_headers(s)
-        assert headers["header"] == ["one"]
-        assert headers["header2"] == ["two"]
-
-    def test_read_multi(self):
-        data = """
-            Header: one
-            Header: two
-            \r\n
-        """
-        data = textwrap.dedent(data)
-        data = data.strip()
-        s = StringIO(data)
-        headers = proxy.read_headers(s)
-        assert headers["header"] == ["one", "two"]
-
-    def test_read_continued(self):
-        data = """
-            Header: one
-            \ttwo
-            Header2: three
-            \r\n
-        """
-        data = textwrap.dedent(data)
-        data = data.strip()
-        s = StringIO(data)
-        headers = proxy.read_headers(s)
-        assert headers["header"] == ['one\r\n two']
-
-
-def test_parse_http_protocol():
-    assert proxy.parse_http_protocol("HTTP/1.1") == (1, 1)
-    assert proxy.parse_http_protocol("HTTP/0.0") == (0, 0)
-    assert not proxy.parse_http_protocol("foo/0.0")
-
-
-def test_parse_init_connect():
-    assert proxy.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
-    assert not proxy.parse_init_connect("bogus")
-    assert not proxy.parse_init_connect("GET host.com:443 HTTP/1.0")
-    assert not proxy.parse_init_connect("CONNECT host.com443 HTTP/1.0")
-    assert not proxy.parse_init_connect("CONNECT host.com:443 foo/1.0")
-
-
-def test_prase_init_proxy():
-    u = "GET http://foo.com:8888/test HTTP/1.1"
-    m, s, h, po, pa, httpversion = proxy.parse_init_proxy(u)
-    assert m == "GET"
-    assert s == "http"
-    assert h == "foo.com"
-    assert po == 8888
-    assert pa == "/test"
-    assert httpversion == (1, 1)
-
-    assert not proxy.parse_init_proxy("invalid")
-    assert not proxy.parse_init_proxy("GET invalid HTTP/1.1")
-    assert not proxy.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
-
-
-def test_parse_init_http():
-    u = "GET /test HTTP/1.1"
-    m, u, httpversion= proxy.parse_init_http(u)
-    assert m == "GET"
-    assert u == "/test"
-    assert httpversion == (1, 1)
-
-    assert not proxy.parse_init_http("invalid")
-    assert not proxy.parse_init_http("GET invalid HTTP/1.1")
-    assert not proxy.parse_init_http("GET /test foo/1.1")
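
For orientation, a minimal usage sketch of the relocated helpers (not part of the patch; Python 2, matching the codebase). The request line, host, and body below are invented for illustration; the calls mirror what the new test_protocol.py exercises.

# Illustrative sketch only -- not part of the diff above.
import cStringIO
from libmproxy import protocol

raw = (
    "GET http://example.com:80/path HTTP/1.1\r\n"
    "Content-Length: 5\r\n"
    "\r\n"
    "hello"
)
rfile = cStringIO.StringIO(raw)

# Request line of a proxied request: method, absolute URL, HTTP version.
method, scheme, host, port, path, httpversion = protocol.parse_init_proxy(rfile.readline())
# Headers up to the blank line, returned as an ODictCaseless.
headers = protocol.read_headers(rfile)
# Body, bounded here by Content-Length; a limit of None means unbounded.
content = protocol.read_http_body(rfile, headers, False, None)

assert (method, scheme, host, port, path) == ("GET", "http", "example.com", 80, "/path")
assert headers["content-length"] == ["5"]
assert content == "hello"

Note that the parse_init_* helpers return None rather than raising on a malformed line, as the negative assertions in test_protocol.py show.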