diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index f94aade30..938b5d5cf 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -263,9 +263,10 @@ class RequestReplayThread(threading.Thread): def run(self): try: - server = ServerConnection(self.config, self.flow.request) - server.send() - response = server.read_response() + r = self.flow.request + server = ServerConnection(self.config, r.scheme, r.host, r.port) + server.send(r) + response = server.read_response(r) response._send(self.masterq) except ProxyError, v: err = flow.Error(self.flow.request, v.msg) @@ -273,14 +274,8 @@ class RequestReplayThread(threading.Thread): class ServerConnection: - def __init__(self, config, request): - self.config, self.request = config, request - if config.reverse_proxy: - self.scheme, self.host, self.port = config.reverse_proxy - else: - self.host = request.host - self.port = request.port - self.scheme = request.scheme + def __init__(self, config, scheme, host, port): + self.config, self.scheme, self.host, self.port = config, scheme, host, port self.close = False self.cert = None self.sock, self.rfile, self.wfile = None, None, None @@ -298,7 +293,6 @@ class ServerConnection: else: clientcert = None server = ssl.wrap_socket(server, certfile = clientcert) - server.connect((addr, self.port)) if self.scheme == "https": self.cert = server.getpeercert(True) @@ -307,18 +301,18 @@ class ServerConnection: self.sock = server self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') - def send(self): - self.request.close = self.close + def send(self, request): + request.close = self.close try: - d = self.request._assemble() + d = request._assemble() if not d: raise ProxyError(502, "Incomplete request could not not be readied for transmission.") self.wfile.write(d) self.wfile.flush() except socket.error, err: - raise ProxyError(502, 'Error sending data to "%s": %s' % (self.request.host, err)) + raise ProxyError(502, 'Error sending data to "%s": %s' % (request.host, err)) - def read_response(self): + def read_response(self, request): line = self.rfile.readline() if line == "\r\n" or line == "\n": # Possible leftover from previous message line = self.rfile.readline() @@ -337,11 +331,11 @@ class ServerConnection: headers = read_headers(self.rfile) if code >= 100 and code <= 199: return self.read_response() - if self.request.method == "HEAD" or code == 204 or code == 304: + if request.method == "HEAD" or code == 204 or code == 304: content = "" else: content = read_http_body(self.rfile, self, headers, True, self.config.body_size_limit) - return flow.Response(self.request, code, msg, headers, content, self.cert) + return flow.Response(request, code, msg, headers, content, self.cert) def terminate(self): try: @@ -393,10 +387,14 @@ class ProxyHandler(SocketServer.StreamRequestHandler): request = False response = response._send(self.mqueue) else: - server_conn = ServerConnection(self.config, request) - server_conn.send() + if self.config.reverse_proxy: + scheme, host, port = self.config.reverse_proxy + server_conn = ServerConnection(self.config, scheme, host, port) + else: + server_conn = ServerConnection(self.config, request.scheme, request.host, request.port) + server_conn.send(request) try: - response = server_conn.read_response() + response = server_conn.read_response(request) except IOError, v: raise IOError, "Reading response: %s"%v response = response._send(self.mqueue)