From 02578151410fff4b3c018303290e2f843e244a89 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 22:24:21 +1300 Subject: [PATCH] Significantly simplify server connection handling, and test. --- libmproxy/proxy.py | 66 ++++++++++++++++++++++++--------------------- test/test_proxy.py | 35 +++--------------------- test/test_server.py | 28 ++++++++++++++----- test/tservers.py | 23 ++++++++-------- 4 files changed, 73 insertions(+), 79 deletions(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index d92e2da9c..7c229064d 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -51,21 +51,22 @@ class ProxyConfig: class ServerConnection(tcp.TCPClient): - def __init__(self, config, host, port): + def __init__(self, config, scheme, host, port, sni): tcp.TCPClient.__init__(self, host, port) self.config = config + self.scheme, self.sni = scheme, sni self.requestcount = 0 - def connect(self, scheme, sni): + def connect(self): tcp.TCPClient.connect(self) - if scheme == "https": + if self.scheme == "https": clientcert = None if self.config.clientcerts: path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem" if os.path.exists(path): clientcert = path try: - self.convert_to_ssl(cert=clientcert, sni=sni) + self.convert_to_ssl(cert=clientcert, sni=self.sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) @@ -94,8 +95,8 @@ class RequestReplayThread(threading.Thread): def run(self): try: r = self.flow.request - server = ServerConnection(self.config, r.host, r.port) - server.connect(r.scheme, r.host) + server = ServerConnection(self.config, r.scheme, r.host, r.port, r.host) + server.connect() server.send(r) httpversion, code, msg, headers, content = http.read_response( server.rfile, r.method, self.config.body_size_limit @@ -109,37 +110,40 @@ class RequestReplayThread(threading.Thread): self.channel.ask(err) -class ServerConnectionPool: - def __init__(self, config): - self.config = config - self.conn = None - - def get_connection(self, scheme, host, port, sni): - sc = self.conn - if self.conn and (host, port) != (sc.host, sc.port): - sc.terminate() - self.conn = None - if not self.conn: - try: - self.conn = ServerConnection(self.config, host, port) - self.conn.connect(scheme, sni) - except tcp.NetLibError, v: - raise ProxyError(502, v) - return self.conn - - def del_connection(self, scheme, host, port): - self.conn = None - - class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, channel, server_version): self.channel, self.server_version = channel, server_version self.config = config - self.server_conn_pool = ServerConnectionPool(config) self.proxy_connect_state = None self.sni = None + self.server_conn = None tcp.BaseHandler.__init__(self, connection, client_address, server) + def get_server_connection(self, cc, scheme, host, port, sni): + sc = self.server_conn + if sc and (scheme, host, port, sni) != (sc.scheme, sc.host, sc.port, sc.sni): + sc.terminate() + self.server_conn = None + self.log( + cc, + "switching connection", [ + "%s://%s:%s (sni=%s) -> %s://%s:%s (sni=%s)"%( + scheme, host, port, sni, + sc.scheme, sc.host, sc.port, sc.sni + ) + ] + ) + if not self.server_conn: + try: + self.server_conn = ServerConnection(self.config, scheme, host, port, sni) + self.server_conn.connect() + except tcp.NetLibError, v: + raise ProxyError(502, v) + return self.server_conn + + def del_server_connection(self): + self.server_conn = None + def handle(self): cc = flow.ClientConnect(self.client_address) self.log(cc, "connect") @@ -190,7 +194,7 @@ class ProxyHandler(tcp.BaseHandler): # the case, we want to reconnect without sending an error # to the client. while 1: - sc = self.server_conn_pool.get_connection(scheme, host, port, host) + sc = self.get_server_connection(cc, scheme, host, port, host) sc.send(request) sc.rfile.reset_timestamps() try: @@ -200,7 +204,7 @@ class ProxyHandler(tcp.BaseHandler): self.config.body_size_limit ) except http.HttpErrorConnClosed, v: - self.server_conn_pool.del_connection(scheme, host, port) + self.del_server_connection() if sc.requestcount > 1: continue else: diff --git a/test/test_proxy.py b/test/test_proxy.py index b575a1d0d..3995b393a 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -39,8 +39,8 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http", "host.com") + sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc.connect() r = tutils.treq() r.path = "/p/200:da" sc.send(r) @@ -53,36 +53,9 @@ class TestServerConnection: sc.terminate() def test_terminate_error(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http", "host.com") + sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc.connect() sc.connection = mock.Mock() sc.connection.close = mock.Mock(side_effect=IOError) sc.terminate() - - -def _dummysc(config, host, port): - return mock.MagicMock(config=config, host=host, port=port) - - -def _errsc(config, host, port): - m = mock.MagicMock(config=config, host=host, port=port) - m.connect = mock.MagicMock(side_effect=tcp.NetLibError()) - return m - - -class TestServerConnectionPool: - @mock.patch("libmproxy.proxy.ServerConnection", _dummysc) - def test_pooling(self): - p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - c = p.get_connection("http", "localhost", 80, "localhost") - c2 = p.get_connection("http", "localhost", 80, "localhost") - assert c is c2 - c3 = p.get_connection("http", "foo", 80, "localhost") - assert not c is c3 - - @mock.patch("libmproxy.proxy.ServerConnection", _errsc) - def test_connection_error(self): - p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - tutils.raises("502", p.get_connection, "http", "localhost", 80, "localhost") - diff --git a/test/test_server.py b/test/test_server.py index 924b63b79..f93ddbb38 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -85,7 +85,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): def test_connection_close(self): # Add a body, so we have a content-length header, which combined with # HTTP1.1 means the connection is kept alive. - response = '%s/p/200:b@1'%self.urlbase + response = '%s/p/200:b@1'%self.server.urlbase # Lets sanity check that the connection does indeed stay open by # issuing two requests over the same connection @@ -99,7 +99,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): tutils.raises("disconnect", p.request, "get:'%s'"%response) def test_reconnect(self): - req = "get:'%s/p/200:b@1:da'"%self.urlbase + req = "get:'%s/p/200:b@1:da'"%self.server.urlbase p = self.pathoc() assert p.request(req) # Server has disconnected. Mitmproxy should detect this, and reconnect. @@ -107,7 +107,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): assert p.request(req) # However, if the server disconnects on our first try, it's an error. - req = "get:'%s/p/200:b@1:d0'"%self.urlbase + req = "get:'%s/p/200:b@1:d0'"%self.server.urlbase p = self.pathoc() tutils.raises("server disconnect", p.request, req) @@ -118,13 +118,29 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): m.side_effect = IOError("error!") tutils.raises("empty reply", self.pathod, "304") + def test_get_connection_switching(self): + def switched(l): + for i in l: + if "switching" in i: + return True + req = "get:'%s/p/200:b@1'" + p = self.pathoc() + assert p.request(req%self.server.urlbase) + assert p.request(req%self.server2.urlbase) + assert switched(self.proxy.log) + + def test_get_connection_err(self): + p = self.pathoc() + ret = p.request("get:'http://localhost:0'") + assert ret[1] == 502 + class TestHTTPS(tservers.HTTPProxTest, SanityMixin): ssl = True clientcerts = True def test_clientcert(self): f = self.pathod("304") - assert self.last_log()["request"]["clientcert"]["keyinfo"] + assert self.server.last_log()["request"]["clientcert"]["keyinfo"] class TestReverse(tservers.ReverseProxTest, SanityMixin): @@ -211,7 +227,7 @@ class TestKillRequest(tservers.HTTPProxTest): p = self.pathoc() tutils.raises("empty reply", self.pathod, "200") # Nothing should have hit the server - assert not self.last_log() + assert not self.server.last_log() class MasterKillResponse(tservers.TestMaster): @@ -225,5 +241,5 @@ class TestKillResponse(tservers.HTTPProxTest): p = self.pathoc() tutils.raises("empty reply", self.pathod, "200") # The server should have seen a request - assert self.last_log() + assert self.server.last_log() diff --git a/test/tservers.py b/test/tservers.py index 262536a77..9597dab40 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -28,6 +28,7 @@ class TestMaster(flow.FlowMaster): state = flow.State() flow.FlowMaster.__init__(self, s, state) self.testq = testq + self.log = [] def handle_request(self, m): flow.FlowMaster.handle_request(self, m) @@ -37,6 +38,10 @@ class TestMaster(flow.FlowMaster): flow.FlowMaster.handle_response(self, m) m.reply() + def handle_log(self, l): + self.log.append(l.msg) + l.reply() + class ProxyThread(threading.Thread): def __init__(self, tmaster): @@ -48,6 +53,10 @@ class ProxyThread(threading.Thread): def port(self): return self.tmaster.server.port + @property + def log(self): + return self.tmaster.log + def run(self): self.tmaster.run() @@ -61,6 +70,7 @@ class ProxTestBase: def setupAll(cls): cls.tqueue = Queue.Queue() cls.server = libpathod.test.Daemon(ssl=cls.ssl) + cls.server2 = libpathod.test.Daemon(ssl=cls.ssl) pconf = cls.get_proxy_config() config = proxy.ProxyConfig( cacert = tutils.test_data.path("data/serverkey.pem"), @@ -78,6 +88,7 @@ class ProxTestBase: def teardownAll(cls): cls.proxy.shutdown() cls.server.shutdown() + cls.server2.shutdown() def setUp(self): self.master.state.clear() @@ -95,16 +106,6 @@ class ProxTestBase: (self.scheme, ("127.0.0.1", self.proxy.port)) ) - @property - def urlbase(self): - """ - The URL base for the server instance. - """ - return self.server.urlbase - - def last_log(self): - return self.server.last_log() - class HTTPProxTest(ProxTestBase): ssl = None @@ -129,7 +130,7 @@ class HTTPProxTest(ProxTestBase): Constructs a pathod request, with the appropriate base and proxy. """ return hurl.get( - self.urlbase + "/p/" + spec, + self.server.urlbase + "/p/" + spec, proxy=self.proxies, validate_cert=False, #debug=hurl.utils.stdout_debug