Refactor ServerConnection API.

This commit is contained in:
Aldo Cortesi 2012-06-10 08:13:50 +12:00
parent 987f443b5d
commit 0c458e2f1a

View File

@ -263,9 +263,10 @@ class RequestReplayThread(threading.Thread):
def run(self):
try:
server = ServerConnection(self.config, self.flow.request)
server.send()
response = server.read_response()
r = self.flow.request
server = ServerConnection(self.config, r.scheme, r.host, r.port)
server.send(r)
response = server.read_response(r)
response._send(self.masterq)
except ProxyError, v:
err = flow.Error(self.flow.request, v.msg)
@ -273,14 +274,8 @@ class RequestReplayThread(threading.Thread):
class ServerConnection:
def __init__(self, config, request):
self.config, self.request = config, request
if config.reverse_proxy:
self.scheme, self.host, self.port = config.reverse_proxy
else:
self.host = request.host
self.port = request.port
self.scheme = request.scheme
def __init__(self, config, scheme, host, port):
self.config, self.scheme, self.host, self.port = config, scheme, host, port
self.close = False
self.cert = None
self.sock, self.rfile, self.wfile = None, None, None
@ -298,7 +293,6 @@ class ServerConnection:
else:
clientcert = None
server = ssl.wrap_socket(server, certfile = clientcert)
server.connect((addr, self.port))
if self.scheme == "https":
self.cert = server.getpeercert(True)
@ -307,18 +301,18 @@ class ServerConnection:
self.sock = server
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
def send(self):
self.request.close = self.close
def send(self, request):
request.close = self.close
try:
d = self.request._assemble()
d = request._assemble()
if not d:
raise ProxyError(502, "Incomplete request could not not be readied for transmission.")
self.wfile.write(d)
self.wfile.flush()
except socket.error, err:
raise ProxyError(502, 'Error sending data to "%s": %s' % (self.request.host, err))
raise ProxyError(502, 'Error sending data to "%s": %s' % (request.host, err))
def read_response(self):
def read_response(self, request):
line = self.rfile.readline()
if line == "\r\n" or line == "\n": # Possible leftover from previous message
line = self.rfile.readline()
@ -337,11 +331,11 @@ class ServerConnection:
headers = read_headers(self.rfile)
if code >= 100 and code <= 199:
return self.read_response()
if self.request.method == "HEAD" or code == 204 or code == 304:
if request.method == "HEAD" or code == 204 or code == 304:
content = ""
else:
content = read_http_body(self.rfile, self, headers, True, self.config.body_size_limit)
return flow.Response(self.request, code, msg, headers, content, self.cert)
return flow.Response(request, code, msg, headers, content, self.cert)
def terminate(self):
try:
@ -393,10 +387,14 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
request = False
response = response._send(self.mqueue)
else:
server_conn = ServerConnection(self.config, request)
server_conn.send()
if self.config.reverse_proxy:
scheme, host, port = self.config.reverse_proxy
server_conn = ServerConnection(self.config, scheme, host, port)
else:
server_conn = ServerConnection(self.config, request.scheme, request.host, request.port)
server_conn.send(request)
try:
response = server_conn.read_response()
response = server_conn.read_response(request)
except IOError, v:
raise IOError, "Reading response: %s"%v
response = response._send(self.mqueue)