diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 7698a61ff..dbe91e7eb 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -104,7 +104,7 @@ def read_chunked(fp, limit): return content -def read_http_body(rfile, connection, headers, all, limit): +def read_http_body(rfile, client_conn, headers, all, limit): if 'transfer-encoding' in headers: if not ",".join(headers["transfer-encoding"]).lower() == "chunked": raise IOError('Invalid transfer-encoding') @@ -121,7 +121,7 @@ def read_http_body(rfile, connection, headers, all, limit): content = rfile.read(l) elif all: content = rfile.read(limit if limit else None) - connection.close = True + client_conn.close = True else: content = "" return content @@ -203,6 +203,18 @@ def should_connection_close(httpversion, headers): return True +def read_http_body_request(rfile, wfile, client_conn, headers, httpversion, limit): + if "expect" in headers: + # FIXME: Should be forwarded upstream + expect = ",".join(headers['expect']) + if expect == "100-continue" and httpversion >= (1, 1): + wfile.write('HTTP/1.1 100 Continue\r\n') + wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) + wfile.write('\r\n') + del headers['expect'] + return read_http_body(rfile, client_conn, headers, False, limit) + + class FileLike: def __init__(self, o): self.o = o @@ -262,10 +274,10 @@ class RequestReplayThread(threading.Thread): class ServerConnection: def __init__(self, config, scheme, host, port): self.config, self.scheme, self.host, self.port = config, scheme, host, port - self.close = False self.cert = None self.sock, self.rfile, self.wfile = None, None, None self.connect() + self.requestcount = 0 def connect(self): try: @@ -288,6 +300,7 @@ class ServerConnection: self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') def send(self, request): + self.requestcount += 1 try: d = request._assemble() if not d: @@ -336,29 +349,38 @@ class ServerConnection: class ProxyHandler(SocketServer.StreamRequestHandler): def __init__(self, config, request, client_address, server, q): - self.config = config self.mqueue = q + self.config = config self.server_conn = None + self.proxy_connect_state = None SocketServer.StreamRequestHandler.__init__(self, request, client_address, server) def handle(self): cc = flow.ClientConnect(self.client_address) cc._send(self.mqueue) - while not cc.close: - self.handle_request(cc) + while self.handle_request(cc) and not cc.close: + pass + cc.close = True cd = flow.ClientDisconnect(cc) cd._send(self.mqueue) self.finish() + def server_connect(self, scheme, host, port): + sc = self.server_conn + if sc and (scheme, host, port) != (sc.scheme, sc.host, sc.port): + sc.terminate() + self.server_conn = None + if not self.server_conn: + self.server_conn = ServerConnection(self.config, scheme, host, port) + def handle_request(self, cc): - server_conn, request, err = None, None, None try: + request, err = None, None try: request = self.read_request(cc) except IOError, v: raise IOError, "Reading request: %s"%v if request is None: - cc.close = True return cc.requestcount += 1 @@ -368,7 +390,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler): else: request = request._send(self.mqueue) if request is None: - cc.close = True return if isinstance(request, flow.Response): @@ -380,31 +401,30 @@ class ProxyHandler(SocketServer.StreamRequestHandler): scheme, host, port = self.config.reverse_proxy else: scheme, host, port = request.scheme, request.host, request.port - server_conn = ServerConnection(self.config, scheme, host, port) - server_conn.send(request) + self.server_connect(scheme, host, port) + self.server_conn.send(request) try: - response = server_conn.read_response(request) + response = self.server_conn.read_response(request) except IOError, v: raise IOError, "Reading response: %s"%v response = response._send(self.mqueue) if response is None: - server_conn.terminate() + self.server_conn.terminate() if response is None: - cc.close = True return self.send_response(response) + if should_connection_close(request.httpversion, request.headers): + return except IOError, v: cc.connection_error = v - cc.close = True except ProxyError, e: - cc.close = True cc.connection_error = "%s: %s"%(e.code, e.msg) if request: err = flow.Error(request, e.msg) err._send(self.mqueue) self.send_error(e.code, e.msg) - if server_conn: - server_conn.terminate() + else: + return True def find_cert(self, host, port): if self.config.certfile: @@ -435,26 +455,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler): self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) - def read_contents(self, client_conn, headers, httpversion): - if "expect" in headers: - # FIXME: Should be forwarded upstream - expect = ",".join(headers['expect']) - if expect == "100-continue" and httpversion >= (1, 1): - self.wfile.write('HTTP/1.1 100 Continue\r\n') - self.wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) - self.wfile.write('\r\n') - del headers['expect'] - if httpversion < (1, 1): - client_conn.close = True - if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - client_conn.close = True - if value == "keep-alive": - client_conn.close = False - return read_http_body(self.rfile, client_conn, headers, False, self.config.body_size_limit) - def read_request(self, client_conn): line = self.rfile.readline() if line == "\r\n" or line == "\n": # Possible leftover from previous message @@ -466,34 +466,45 @@ class ProxyHandler(SocketServer.StreamRequestHandler): scheme, host, port = self.config.reverse_proxy method, path, httpversion = parse_init_http(line) headers = read_headers(self.rfile) - content = self.read_contents(client_conn, headers, httpversion) + content = read_http_body_request( + self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit + ) return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content) - elif line.startswith("CONNECT"): - host, port, httpversion = parse_init_connect(line) - # FIXME: Discard additional headers sent to the proxy. Should I expose - # these to users? - while 1: - d = self.rfile.readline() - if d == '\r\n' or d == '\n': - break - self.wfile.write( - 'HTTP/1.1 200 Connection established\r\n' + - ('Proxy-agent: %s\r\n'%version.NAMEVERSION) + - '\r\n' - ) - self.wfile.flush() - certfile = self.find_cert(host, port) - self.convert_to_ssl(certfile) - - method, path, httpversion = parse_init_http(self.rfile.readline(line)) - headers = read_headers(self.rfile) - content = self.read_contents(client_conn, headers, httpversion) - return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content) else: - method, scheme, host, port, path, httpversion = parse_init_proxy(line) - headers = read_headers(self.rfile) - content = self.read_contents(client_conn, headers, httpversion) - return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content) + if line.startswith("CONNECT"): + host, port, httpversion = parse_init_connect(line) + # FIXME: Discard additional headers sent to the proxy. Should I expose + # these to users? + while 1: + d = self.rfile.readline() + if d == '\r\n' or d == '\n': + break + self.wfile.write( + 'HTTP/1.1 200 Connection established\r\n' + + ('Proxy-agent: %s\r\n'%version.NAMEVERSION) + + '\r\n' + ) + self.wfile.flush() + certfile = self.find_cert(host, port) + self.convert_to_ssl(certfile) + self.proxy_connect_state = (host, port, httpversion) + line = self.rfile.readline(line) + + if self.proxy_connect_state: + host, port, httpversion = self.proxy_connect_state + method, path, httpversion = parse_init_http(line) + headers = read_headers(self.rfile) + content = read_http_body_request( + self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit + ) + return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content) + else: + method, scheme, host, port, path, httpversion = parse_init_proxy(line) + headers = read_headers(self.rfile) + content = read_http_body_request( + self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit + ) + return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content) def send_response(self, response): d = response._assemble()