diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index f14e4e3ef..3bbb82ba3 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -107,12 +107,30 @@ class ServerConnection(tcp.TCPClient): except IOError: pass +class ServerConnectionPool: + def __init__(self, config): + self.config = config + self.conn = None + + def get_connection(self, scheme, host, port): + sc = self.conn + if self.conn and (host, port) != (sc.host, sc.port): + sc.terminate() + self.conn = None + if not self.conn: + try: + self.conn = ServerConnection(self.config, host, port) + self.conn.connect(scheme) + except tcp.NetLibError, v: + raise ProxyError(502, v) + return self.conn + class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, mqueue, server_version): self.mqueue, self.server_version = mqueue, server_version self.config = config - self.server_conn = None + self.server_conn_pool = ServerConnectionPool(config) self.proxy_connect_state = None self.sni = None tcp.BaseHandler.__init__(self, connection, client_address, server) @@ -133,18 +151,6 @@ class ProxyHandler(tcp.BaseHandler): ) cd._send(self.mqueue) - def server_connect(self, scheme, host, port): - sc = self.server_conn - if sc and (host, port) != (sc.host, sc.port): - sc.terminate() - self.server_conn = None - if not self.server_conn: - try: - self.server_conn = ServerConnection(self.config, host, port) - self.server_conn.connect(scheme) - except tcp.NetLibError, v: - raise ProxyError(502, v) - def handle_request(self, cc): try: request, err = None, None @@ -173,21 +179,21 @@ class ProxyHandler(tcp.BaseHandler): scheme, host, port = self.config.reverse_proxy else: scheme, host, port = request.scheme, request.host, request.port - self.server_connect(scheme, host, port) - self.server_conn.send(request) - self.server_conn.rfile.reset_timestamps() + sc = self.server_conn_pool.get_connection(scheme, host, port) + sc.send(request) + sc.rfile.reset_timestamps() httpversion, code, msg, headers, content = http.read_response( - self.server_conn.rfile, + sc.rfile, request.method, self.config.body_size_limit ) response = flow.Response( - request, httpversion, code, msg, headers, content, self.server_conn.cert, self.server_conn.rfile.first_byte_timestamp, utils.timestamp() + request, httpversion, code, msg, headers, content, sc.cert, + sc.rfile.first_byte_timestamp, utils.timestamp() ) - response = response._send(self.mqueue) if response is None: - self.server_conn.terminate() + sc.terminate() if response is None: return self.send_response(response) @@ -310,7 +316,7 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) - + def read_request_proxy(self, client_conn): line = self.get_line(self.rfile) if line == "":