add tests for reconnect to upstream proxy, ensure that server_reconnect is always hooked

This commit is contained in:
Maximilian Hils 2014-02-07 18:14:15 +01:00
parent 545fc2506b
commit 735e4400c4
5 changed files with 137 additions and 41 deletions

View File

@ -883,8 +883,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
self.process_request(flow.request)
request_reply = self.c.channel.ask("request", flow.request)
flow.server_conn = self.c.server_conn # no further manipulation of self.c.server_conn beyond this point.
# we can safely set it as the final attribute valueh here.
flow.server_conn = self.c.server_conn
if request_reply is None or request_reply == KILL:
return False
@ -894,14 +893,15 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
else:
flow.response = self.get_response_from_server(flow.request)
flow.server_conn = self.c.server_conn # no further manipulation of self.c.server_conn beyond this point
# we can safely set it as the final attribute value here.
self.c.log("response", [flow.response._assemble_first_line()])
response_reply = self.c.channel.ask("response", flow.response)
if response_reply is None or response_reply == KILL:
return False
raw = flow.response._assemble()
self.c.client_conn.wfile.write(raw)
self.c.client_conn.wfile.flush()
self.c.client_conn.send(flow.response._assemble())
flow.timestamp_end = utils.timestamp()
if (http.connection_close(flow.request.httpversion, flow.request.headers) or
@ -909,7 +909,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
return False
if flow.request.form_in == "authority":
self.ssl_upgrade(flow.request)
self.ssl_upgrade()
self.restore_server() # If the user has changed the target server on this connection,
# restore the original target server
@ -968,7 +968,27 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
self.c.client_conn.wfile.write(html_content)
self.c.client_conn.wfile.flush()
def ssl_upgrade(self, upstream_request=None):
def hook_reconnect(self, upstream_request):
self.c.log("Hook reconnect function")
original_reconnect_func = self.c.server_reconnect
def reconnect_http_proxy():
self.c.log("Hooked reconnect function")
self.c.log("Hook: Run original reconnect")
original_reconnect_func(no_ssl=True)
self.c.log("Hook: Write CONNECT request to upstream proxy", [upstream_request._assemble_first_line()])
self.c.server_conn.send(upstream_request._assemble())
self.c.log("Hook: Read answer to CONNECT request from proxy")
resp = HTTPResponse.from_stream(self.c.server_conn.rfile, upstream_request.method)
if resp.code != 200:
raise ProxyError(resp.code,
"Cannot reestablish SSL connection with upstream proxy: \r\n" + str(resp.headers))
self.c.log("Hook: Establish SSL with upstream proxy")
self.c.establish_ssl(server=True)
self.c.server_reconnect = reconnect_http_proxy
def ssl_upgrade(self):
"""
Upgrade the connection to SSL after an authority (CONNECT) request has been made.
If the authority request has been forwarded upstream (because we have another proxy server there),
@ -981,28 +1001,6 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
self.c.mode = "transparent"
self.c.determine_conntype()
self.c.establish_ssl(server=True, client=True)
if upstream_request:
self.c.log("Hook reconnect function")
original_reconnect_func = self.c.server_reconnect
def reconnect_http_proxy():
self.c.log("Hooked reconnect function")
self.c.log("Hook: Run original reconnect")
original_reconnect_func(no_ssl=True)
self.c.log("Hook: Write CONNECT request to upstream proxy", [upstream_request._assemble_first_line()])
self.c.server_conn.wfile.write(upstream_request._assemble())
self.c.server_conn.wfile.flush()
self.c.log("Hook: Read answer to CONNECT request from proxy")
resp = HTTPResponse.from_stream(self.c.server_conn.rfile, upstream_request.method)
if resp.code != 200:
raise ProxyError(resp.code,
"Cannot reestablish SSL connection with upstream proxy: \r\n" + str(resp.headers))
self.c.log("Hook: Establish SSL with upstream proxy")
self.c.establish_ssl(server=True)
self.c.server_reconnect = reconnect_http_proxy
self.c.log("Upgrade to SSL completed.")
raise ConnectionTypeChange
@ -1028,7 +1026,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
if self.c.mode == "regular":
if request.form_in == "authority": # forward mode
pass
self.hook_reconnect(request)
elif request.form_in == "absolute":
if request.scheme != "http":
raise http.HttpError(400, "Invalid Request")
@ -1037,7 +1035,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL)
request.flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
else:
raise http.HttpError(400, "Invalid Request")
raise http.HttpError(400, "Invalid request form (absolute-form or authority-form required)")
def authenticate(self, request):
if self.c.config.authenticator:

View File

@ -50,6 +50,9 @@ class Error(stateobject.SimpleStateObject):
timestamp=float
)
def __str__(self):
return self.msg
@classmethod
def _from_state(cls, state):
f = cls(None) # the default implementation assumes an empty constructor. Override accordingly.

View File

@ -89,6 +89,10 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
def copy(self):
return copy.copy(self)
def send(self, message):
self.wfile.write(message)
self.wfile.flush()
@classmethod
def _from_state(cls, state):
f = cls(None, tuple(), None)

View File

@ -1,5 +1,6 @@
from libmproxy import proxy # FIXME: Remove
from libmproxy.protocol.http import *
from libmproxy.protocol import KILL
from cStringIO import StringIO
import tutils, tservers
@ -86,6 +87,23 @@ class TestHTTPResponse:
tutils.raises("Invalid server response: 'content", HTTPResponse.from_stream, s, "GET")
class TestInvalidRequests(tservers.HTTPProxTest):
ssl = True
def test_double_connect(self):
p = self.pathoc()
r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port))
assert r.status_code == 502
assert "Must not CONNECT on already encrypted connection" in r.content
def test_origin_request(self):
p = self.pathoc_raw()
p.connect()
r = p.request("get:/p/200")
assert r.status_code == 400
assert "Invalid request form" in r.content
class TestProxyChaining(tservers.HTTPChainProxyTest):
def test_all(self):
self.chain[1].tmaster.replacehooks.add("~q", "foo", "bar") # replace in request
@ -98,16 +116,83 @@ class TestProxyChaining(tservers.HTTPChainProxyTest):
assert req.content == "ORLY"
assert req.status_code == 418
self.chain[0].tmaster.replacehooks.clear()
self.chain[1].tmaster.replacehooks.clear()
self.proxy.tmaster.replacehooks.clear()
class TestProxyChainingSSL(tservers.HTTPChainProxyTest):
ssl = True
def test_full(self):
def test_simple(self):
p = self.pathoc()
req = p.request("get:'/p/418:b@100'")
assert len(req.content) == 100
assert req.status_code == 418
req = p.request("get:'/p/418:b\"content\"'")
assert req.content == "content"
assert req.status_code == 418
assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT from pathoc to chain[0],
# request from pathoc to chain[0]
assert self.chain[0].tmaster.state.flow_count() == 2 # CONNECT from chain[1] to proxy,
# request from chain[1] to proxy
assert self.proxy.tmaster.state.flow_count() == 1 # request from chain[0] (regular proxy doesn't store CONNECTs)
def test_reconnect(self):
"""
Tests proper functionality of ConnectionHandler.server_reconnect mock.
If we have a disconnect on a secure connection that's transparently proxified to
an upstream http proxy, we need to send the CONNECT request again.
"""
def kill_requests(master, attr, exclude):
k = [0] # variable scope workaround: put into array
_func = getattr(master, attr)
def handler(r):
k[0] += 1
if not (k[0] in exclude):
r.flow.client_conn.finish()
r.flow.error = Error("terminated")
r.reply(KILL)
return _func(r)
setattr(master, attr, handler)
kill_requests(self.proxy.tmaster, "handle_request",
exclude=[
# fail first request
2, # allow second request
])
kill_requests(self.chain[0].tmaster, "handle_request",
exclude=[
1, # CONNECT
# fail first request
3, # reCONNECT
4, # request
])
p = self.pathoc()
req = p.request("get:'/p/418:b\"content\"'")
assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT and request
assert self.chain[0].tmaster.state.flow_count() == 4 # CONNECT, failing request,
# reCONNECT, request
assert self.proxy.tmaster.state.flow_count() == 2 # failing request, request
# (doesn't store (repeated) CONNECTs from chain[0]
# as it is a regular proxy)
assert req.content == "content"
assert req.status_code == 418
assert not self.proxy.tmaster.state._flow_list[0].response # killed
assert self.proxy.tmaster.state._flow_list[1].response
assert self.chain[1].tmaster.state._flow_list[0].request.form_in == "authority"
assert self.chain[1].tmaster.state._flow_list[1].request.form_in == "origin"
assert self.chain[0].tmaster.state._flow_list[0].request.form_in == "authority"
assert self.chain[0].tmaster.state._flow_list[1].request.form_in == "origin"
assert self.chain[0].tmaster.state._flow_list[2].request.form_in == "authority"
assert self.chain[0].tmaster.state._flow_list[3].request.form_in == "origin"
assert self.proxy.tmaster.state._flow_list[0].request.form_in == "origin"
assert self.proxy.tmaster.state._flow_list[1].request.form_in == "origin"
req = p.request("get:'/p/418:b\"content2\"'")
assert req.status_code == 502
assert self.chain[1].tmaster.state.flow_count() == 3 # + new request
assert self.chain[0].tmaster.state.flow_count() == 6 # + new request, repeated CONNECT from chain[1]
# (both terminated)
assert self.proxy.tmaster.state.flow_count() == 2 # nothing happened here

View File

@ -254,7 +254,6 @@ class ChainProxTest(ProxTestBase):
Chain n instances of mitmproxy in a row - because we can.
"""
n = 2
chain = []
chain_config = [lambda: proxy.ProxyConfig(
cacert = tutils.test_data.path("data/serverkey.pem")
)] * n
@ -262,6 +261,7 @@ class ChainProxTest(ProxTestBase):
@classmethod
def setupAll(cls):
super(ChainProxTest, cls).setupAll()
cls.chain = []
for i in range(cls.n):
config = cls.chain_config[i]()
config.forward_proxy = ("http", "127.0.0.1",
@ -278,6 +278,12 @@ class ChainProxTest(ProxTestBase):
for p in cls.chain:
p.tmaster.server.shutdown()
def setUp(self):
super(ChainProxTest, self).setUp()
for p in self.chain:
p.tmaster.clear_log()
p.tmaster.state.clear()
class HTTPChainProxyTest(ChainProxTest):
def pathoc(self, sni=None):