diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index a6a72d55b..458ea2b5c 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -326,11 +326,11 @@ class ProxyHandler(tcp.BaseHandler): if not self.ssl_established and (port in self.config.transparent_proxy["sslports"]): scheme = "https" dummycert = self.find_cert(client_conn, host, port, host) + sni = HandleSNI( + self, client_conn, host, port, + dummycert, self.config.certfile or self.config.cacert + ) try: - sni = HandleSNI( - self, client_conn, host, port, - dummycert, self.config.certfile or self.config.cacert - ) self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) @@ -356,31 +356,29 @@ class ProxyHandler(tcp.BaseHandler): line = self.get_line(self.rfile) if line == "": return None - if http.parse_init_connect(line): - r = http.parse_init_connect(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - host, port, httpversion = r - headers = self.read_headers(authenticate=True) - - self.wfile.write( - 'HTTP/1.1 200 Connection established\r\n' + - ('Proxy-agent: %s\r\n'%self.server_version) + - '\r\n' - ) - self.wfile.flush() - dummycert = self.find_cert(client_conn, host, port, host) - try: + if not self.proxy_connect_state: + connparts = http.parse_init_connect(line) + if connparts: + host, port, httpversion = connparts + headers = self.read_headers(authenticate=True) + self.wfile.write( + 'HTTP/1.1 200 Connection established\r\n' + + ('Proxy-agent: %s\r\n'%self.server_version) + + '\r\n' + ) + self.wfile.flush() + dummycert = self.find_cert(client_conn, host, port, host) sni = HandleSNI( self, client_conn, host, port, dummycert, self.config.certfile or self.config.cacert ) - self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) - except tcp.NetLibError, v: - raise ProxyError(400, str(v)) - self.proxy_connect_state = (host, port, httpversion) - line = self.rfile.readline(line) + try: + self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) + except tcp.NetLibError, v: + raise ProxyError(400, str(v)) + self.proxy_connect_state = (host, port, httpversion) + line = self.rfile.readline(line) if self.proxy_connect_state: r = http.parse_init_http(line) diff --git a/test/test_server.py b/test/test_server.py index 86a75452e..3a1b019fd 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -45,6 +45,12 @@ class CommonMixin: assert "host" in l.request.headers assert l.response.code == 304 + def test_invalid_http(self): + t = tcp.TCPClient("127.0.0.1", self.proxy.port) + t.connect() + t.wfile.write("invalid\r\n\r\n") + t.wfile.flush() + assert "Bad Request" in t.rfile.readline() class TestHTTP(tservers.HTTPProxTest, CommonMixin): @@ -54,13 +60,6 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin): assert ret.status_code == 500 assert "ValueError" in ret.content - def test_invalid_http(self): - t = tcp.TCPClient("127.0.0.1", self.proxy.port) - t.connect() - t.wfile.write("invalid\n\n") - t.wfile.flush() - assert "Bad Request" in t.rfile.readline() - def test_invalid_connect(self): t = tcp.TCPClient("127.0.0.1", self.proxy.port) t.connect() @@ -125,6 +124,25 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin): ret = p.request("get:'http://localhost:0'") assert ret.status_code == 502 + def test_blank_leading_line(self): + p = self.pathoc() + req = "get:'%s/p/201':i0,'\r\n'" + assert p.request(req%self.server.urlbase).status_code == 201 + + def test_invalid_headers(self): + p = self.pathoc() + req = p.request("get:'http://foo':h':foo'='bar'") + print req + + +class TestHTTPConnectSSLError(tservers.HTTPProxTest): + certfile = True + def test_go(self): + p = self.pathoc() + req = "connect:'localhost:%s'"%self.proxy.port + assert p.request(req).status_code == 200 + assert p.request(req).status_code == 400 + class TestHTTPS(tservers.HTTPProxTest, CommonMixin): ssl = True @@ -140,6 +158,11 @@ class TestHTTPS(tservers.HTTPProxTest, CommonMixin): l = self.server.last_log() assert self.server.last_log()["request"]["sni"] == "testserver.com" + def test_error_post_connect(self): + p = self.pathoc() + assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 + + class TestHTTPSNoUpstream(tservers.HTTPProxTest, CommonMixin): ssl = True @@ -163,12 +186,10 @@ class TestReverse(tservers.ReverseProxTest, CommonMixin): class TestTransparent(tservers.TransparentProxTest, CommonMixin): - transparent = True ssl = False class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): - transparent = True ssl = True def test_sni(self): f = self.pathod("304", sni="testserver.com") @@ -176,6 +197,10 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): l = self.server.last_log() assert self.server.last_log()["request"]["sni"] == "testserver.com" + def test_sslerr(self): + p = pathoc.Pathoc("localhost", self.proxy.port) + p.connect() + assert p.request("get:/").status_code == 400 class TestProxy(tservers.HTTPProxTest): @@ -267,3 +292,19 @@ class TestKillResponse(tservers.HTTPProxTest): # The server should have seen a request assert self.server.last_log() + +class EResolver(tservers.TResolver): + def original_addr(self, sock): + return None + + +class TestTransparentResolveError(tservers.TransparentProxTest): + resolver = EResolver + def test_resolve_error(self): + assert self.pathod("304").status_code == 502 + + + + + + diff --git a/test/tservers.py b/test/tservers.py index 4efed7e2d..7672f34ab 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -131,7 +131,7 @@ class ProxTestBase: class HTTPProxTest(ProxTestBase): def pathoc_raw(self): return libpathod.pathoc.Pathoc("127.0.0.1", self.proxy.port) - + def pathoc(self, sni=None): """ Returns a connected Pathoc instance. @@ -148,6 +148,7 @@ class HTTPProxTest(ProxTestBase): Constructs a pathod GET request, with the appropriate base and proxy. """ p = self.pathoc(sni=sni) + spec = spec.encode("string_escape") if self.ssl: q = "get:'/p/%s'"%spec else: @@ -165,6 +166,7 @@ class TResolver: class TransparentProxTest(ProxTestBase): ssl = None + resolver = TResolver @classmethod def get_proxy_config(cls): d = ProxTestBase.get_proxy_config() @@ -173,7 +175,7 @@ class TransparentProxTest(ProxTestBase): else: ports = [] d["transparent_proxy"] = dict( - resolver = TResolver(cls.server.port), + resolver = cls.resolver(cls.server.port), sslports = ports ) return d