diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 088fe94c8..d92e2da9c 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -50,36 +50,13 @@ class ProxyConfig: self.certstore = certutils.CertStore(certdir) -class RequestReplayThread(threading.Thread): - def __init__(self, config, flow, masterq): - self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) - threading.Thread.__init__(self) - - def run(self): - try: - r = self.flow.request - server = ServerConnection(self.config, r.host, r.port) - server.connect(r.scheme) - server.send(r) - httpversion, code, msg, headers, content = http.read_response( - server.rfile, r.method, self.config.body_size_limit - ) - response = flow.Response( - self.flow.request, httpversion, code, msg, headers, content, server.cert - ) - self.channel.ask(response) - except (ProxyError, http.HttpError, tcp.NetLibError), v: - err = flow.Error(self.flow.request, str(v)) - self.channel.ask(err) - - class ServerConnection(tcp.TCPClient): def __init__(self, config, host, port): tcp.TCPClient.__init__(self, host, port) self.config = config self.requestcount = 0 - def connect(self, scheme): + def connect(self, scheme, sni): tcp.TCPClient.connect(self) if scheme == "https": clientcert = None @@ -88,7 +65,7 @@ class ServerConnection(tcp.TCPClient): if os.path.exists(path): clientcert = path try: - self.convert_to_ssl(clientcert=clientcert, sni=self.host) + self.convert_to_ssl(cert=clientcert, sni=sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) @@ -109,12 +86,35 @@ class ServerConnection(tcp.TCPClient): pass +class RequestReplayThread(threading.Thread): + def __init__(self, config, flow, masterq): + self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) + threading.Thread.__init__(self) + + def run(self): + try: + r = self.flow.request + server = ServerConnection(self.config, r.host, r.port) + server.connect(r.scheme, r.host) + server.send(r) + httpversion, code, msg, headers, content = http.read_response( + server.rfile, r.method, self.config.body_size_limit + ) + response = flow.Response( + self.flow.request, httpversion, code, msg, headers, content, server.cert + ) + self.channel.ask(response) + except (ProxyError, http.HttpError, tcp.NetLibError), v: + err = flow.Error(self.flow.request, str(v)) + self.channel.ask(err) + + class ServerConnectionPool: def __init__(self, config): self.config = config self.conn = None - def get_connection(self, scheme, host, port): + def get_connection(self, scheme, host, port, sni): sc = self.conn if self.conn and (host, port) != (sc.host, sc.port): sc.terminate() @@ -122,7 +122,7 @@ class ServerConnectionPool: if not self.conn: try: self.conn = ServerConnection(self.config, host, port) - self.conn.connect(scheme) + self.conn.connect(scheme, sni) except tcp.NetLibError, v: raise ProxyError(502, v) return self.conn @@ -190,18 +190,18 @@ class ProxyHandler(tcp.BaseHandler): # the case, we want to reconnect without sending an error # to the client. while 1: + sc = self.server_conn_pool.get_connection(scheme, host, port, host) + sc.send(request) + sc.rfile.reset_timestamps() try: - sc = self.server_conn_pool.get_connection(scheme, host, port) - sc.send(request) - sc.rfile.reset_timestamps() httpversion, code, msg, headers, content = http.read_response( sc.rfile, request.method, self.config.body_size_limit ) except http.HttpErrorConnClosed, v: + self.server_conn_pool.del_connection(scheme, host, port) if sc.requestcount > 1: - self.server_conn_pool.del_connection(scheme, host, port) continue else: raise @@ -324,25 +324,6 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) - def read_request_reverse(self, client_conn): - line = self.get_line(self.rfile) - if line == "": - return None - scheme, host, port = self.config.reverse_proxy - r = http.parse_init_http(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, path, httpversion = r - headers = self.read_headers(authenticate=False) - content = http.read_http_body_request( - self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit - ) - return flow.Request( - client_conn, httpversion, host, port, "http", method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - - def read_request_proxy(self, client_conn): line = self.get_line(self.rfile) if line == "": @@ -398,6 +379,24 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) + def read_request_reverse(self, client_conn): + line = self.get_line(self.rfile) + if line == "": + return None + scheme, host, port = self.config.reverse_proxy + r = http.parse_init_http(line) + if not r: + raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) + method, path, httpversion = r + headers = self.read_headers(authenticate=False) + content = http.read_http_body_request( + self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit + ) + return flow.Request( + client_conn, httpversion, host, port, "http", method, path, headers, content, + self.rfile.first_byte_timestamp, utils.timestamp() + ) + def read_request(self, client_conn): self.rfile.reset_timestamps() if self.config.transparent_proxy: diff --git a/test/test_proxy.py b/test/test_proxy.py index bdac8697a..b575a1d0d 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -40,7 +40,7 @@ class TestServerConnection: def test_simple(self): sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http") + sc.connect("http", "host.com") r = tutils.treq() r.path = "/p/200:da" sc.send(r) @@ -54,7 +54,7 @@ class TestServerConnection: def test_terminate_error(self): sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http") + sc.connect("http", "host.com") sc.connection = mock.Mock() sc.connection.close = mock.Mock(side_effect=IOError) sc.terminate() @@ -75,14 +75,14 @@ class TestServerConnectionPool: @mock.patch("libmproxy.proxy.ServerConnection", _dummysc) def test_pooling(self): p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - c = p.get_connection("http", "localhost", 80) - c2 = p.get_connection("http", "localhost", 80) + c = p.get_connection("http", "localhost", 80, "localhost") + c2 = p.get_connection("http", "localhost", 80, "localhost") assert c is c2 - c3 = p.get_connection("http", "foo", 80) + c3 = p.get_connection("http", "foo", 80, "localhost") assert not c is c3 @mock.patch("libmproxy.proxy.ServerConnection", _errsc) def test_connection_error(self): p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - tutils.raises("502", p.get_connection, "http", "localhost", 80) + tutils.raises("502", p.get_connection, "http", "localhost", 80, "localhost")