Extend unit tests for proxy.py to some tricky cases.

This commit is contained in:
Aldo Cortesi 2013-03-02 22:42:36 +13:00
parent 415844511c
commit c20d1d7d32
3 changed files with 76 additions and 35 deletions

View File

@ -326,11 +326,11 @@ class ProxyHandler(tcp.BaseHandler):
if not self.ssl_established and (port in self.config.transparent_proxy["sslports"]):
scheme = "https"
dummycert = self.find_cert(client_conn, host, port, host)
sni = HandleSNI(
self, client_conn, host, port,
dummycert, self.config.certfile or self.config.cacert
)
try:
sni = HandleSNI(
self, client_conn, host, port,
dummycert, self.config.certfile or self.config.cacert
)
self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni)
except tcp.NetLibError, v:
raise ProxyError(400, str(v))
@ -356,31 +356,29 @@ class ProxyHandler(tcp.BaseHandler):
line = self.get_line(self.rfile)
if line == "":
return None
if http.parse_init_connect(line):
r = http.parse_init_connect(line)
if not r:
raise ProxyError(400, "Bad HTTP request line: %s"%repr(line))
host, port, httpversion = r
headers = self.read_headers(authenticate=True)
self.wfile.write(
'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n'%self.server_version) +
'\r\n'
)
self.wfile.flush()
dummycert = self.find_cert(client_conn, host, port, host)
try:
if not self.proxy_connect_state:
connparts = http.parse_init_connect(line)
if connparts:
host, port, httpversion = connparts
headers = self.read_headers(authenticate=True)
self.wfile.write(
'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n'%self.server_version) +
'\r\n'
)
self.wfile.flush()
dummycert = self.find_cert(client_conn, host, port, host)
sni = HandleSNI(
self, client_conn, host, port,
dummycert, self.config.certfile or self.config.cacert
)
self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni)
except tcp.NetLibError, v:
raise ProxyError(400, str(v))
self.proxy_connect_state = (host, port, httpversion)
line = self.rfile.readline(line)
try:
self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni)
except tcp.NetLibError, v:
raise ProxyError(400, str(v))
self.proxy_connect_state = (host, port, httpversion)
line = self.rfile.readline(line)
if self.proxy_connect_state:
r = http.parse_init_http(line)

View File

@ -45,6 +45,12 @@ class CommonMixin:
assert "host" in l.request.headers
assert l.response.code == 304
def test_invalid_http(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
t.connect()
t.wfile.write("invalid\r\n\r\n")
t.wfile.flush()
assert "Bad Request" in t.rfile.readline()
class TestHTTP(tservers.HTTPProxTest, CommonMixin):
@ -54,13 +60,6 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin):
assert ret.status_code == 500
assert "ValueError" in ret.content
def test_invalid_http(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
t.connect()
t.wfile.write("invalid\n\n")
t.wfile.flush()
assert "Bad Request" in t.rfile.readline()
def test_invalid_connect(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
t.connect()
@ -125,6 +124,25 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin):
ret = p.request("get:'http://localhost:0'")
assert ret.status_code == 502
def test_blank_leading_line(self):
p = self.pathoc()
req = "get:'%s/p/201':i0,'\r\n'"
assert p.request(req%self.server.urlbase).status_code == 201
def test_invalid_headers(self):
p = self.pathoc()
req = p.request("get:'http://foo':h':foo'='bar'")
print req
class TestHTTPConnectSSLError(tservers.HTTPProxTest):
certfile = True
def test_go(self):
p = self.pathoc()
req = "connect:'localhost:%s'"%self.proxy.port
assert p.request(req).status_code == 200
assert p.request(req).status_code == 400
class TestHTTPS(tservers.HTTPProxTest, CommonMixin):
ssl = True
@ -140,6 +158,11 @@ class TestHTTPS(tservers.HTTPProxTest, CommonMixin):
l = self.server.last_log()
assert self.server.last_log()["request"]["sni"] == "testserver.com"
def test_error_post_connect(self):
p = self.pathoc()
assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400
class TestHTTPSNoUpstream(tservers.HTTPProxTest, CommonMixin):
ssl = True
@ -163,12 +186,10 @@ class TestReverse(tservers.ReverseProxTest, CommonMixin):
class TestTransparent(tservers.TransparentProxTest, CommonMixin):
transparent = True
ssl = False
class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin):
transparent = True
ssl = True
def test_sni(self):
f = self.pathod("304", sni="testserver.com")
@ -176,6 +197,10 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin):
l = self.server.last_log()
assert self.server.last_log()["request"]["sni"] == "testserver.com"
def test_sslerr(self):
p = pathoc.Pathoc("localhost", self.proxy.port)
p.connect()
assert p.request("get:/").status_code == 400
class TestProxy(tservers.HTTPProxTest):
@ -267,3 +292,19 @@ class TestKillResponse(tservers.HTTPProxTest):
# The server should have seen a request
assert self.server.last_log()
class EResolver(tservers.TResolver):
def original_addr(self, sock):
return None
class TestTransparentResolveError(tservers.TransparentProxTest):
resolver = EResolver
def test_resolve_error(self):
assert self.pathod("304").status_code == 502

View File

@ -131,7 +131,7 @@ class ProxTestBase:
class HTTPProxTest(ProxTestBase):
def pathoc_raw(self):
return libpathod.pathoc.Pathoc("127.0.0.1", self.proxy.port)
def pathoc(self, sni=None):
"""
Returns a connected Pathoc instance.
@ -148,6 +148,7 @@ class HTTPProxTest(ProxTestBase):
Constructs a pathod GET request, with the appropriate base and proxy.
"""
p = self.pathoc(sni=sni)
spec = spec.encode("string_escape")
if self.ssl:
q = "get:'/p/%s'"%spec
else:
@ -165,6 +166,7 @@ class TResolver:
class TransparentProxTest(ProxTestBase):
ssl = None
resolver = TResolver
@classmethod
def get_proxy_config(cls):
d = ProxTestBase.get_proxy_config()
@ -173,7 +175,7 @@ class TransparentProxTest(ProxTestBase):
else:
ports = []
d["transparent_proxy"] = dict(
resolver = TResolver(cls.server.port),
resolver = cls.resolver(cls.server.port),
sslports = ports
)
return d