Refactor ServerConnection API.

This commit is contained in:
Aldo Cortesi 2012-06-10 08:13:50 +12:00
parent 987f443b5d
commit 0c458e2f1a

View File

@ -263,9 +263,10 @@ class RequestReplayThread(threading.Thread):
def run(self): def run(self):
try: try:
server = ServerConnection(self.config, self.flow.request) r = self.flow.request
server.send() server = ServerConnection(self.config, r.scheme, r.host, r.port)
response = server.read_response() server.send(r)
response = server.read_response(r)
response._send(self.masterq) response._send(self.masterq)
except ProxyError, v: except ProxyError, v:
err = flow.Error(self.flow.request, v.msg) err = flow.Error(self.flow.request, v.msg)
@ -273,14 +274,8 @@ class RequestReplayThread(threading.Thread):
class ServerConnection: class ServerConnection:
def __init__(self, config, request): def __init__(self, config, scheme, host, port):
self.config, self.request = config, request self.config, self.scheme, self.host, self.port = config, scheme, host, port
if config.reverse_proxy:
self.scheme, self.host, self.port = config.reverse_proxy
else:
self.host = request.host
self.port = request.port
self.scheme = request.scheme
self.close = False self.close = False
self.cert = None self.cert = None
self.sock, self.rfile, self.wfile = None, None, None self.sock, self.rfile, self.wfile = None, None, None
@ -298,7 +293,6 @@ class ServerConnection:
else: else:
clientcert = None clientcert = None
server = ssl.wrap_socket(server, certfile = clientcert) server = ssl.wrap_socket(server, certfile = clientcert)
server.connect((addr, self.port)) server.connect((addr, self.port))
if self.scheme == "https": if self.scheme == "https":
self.cert = server.getpeercert(True) self.cert = server.getpeercert(True)
@ -307,18 +301,18 @@ class ServerConnection:
self.sock = server self.sock = server
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
def send(self): def send(self, request):
self.request.close = self.close request.close = self.close
try: try:
d = self.request._assemble() d = request._assemble()
if not d: if not d:
raise ProxyError(502, "Incomplete request could not not be readied for transmission.") raise ProxyError(502, "Incomplete request could not not be readied for transmission.")
self.wfile.write(d) self.wfile.write(d)
self.wfile.flush() self.wfile.flush()
except socket.error, err: except socket.error, err:
raise ProxyError(502, 'Error sending data to "%s": %s' % (self.request.host, err)) raise ProxyError(502, 'Error sending data to "%s": %s' % (request.host, err))
def read_response(self): def read_response(self, request):
line = self.rfile.readline() line = self.rfile.readline()
if line == "\r\n" or line == "\n": # Possible leftover from previous message if line == "\r\n" or line == "\n": # Possible leftover from previous message
line = self.rfile.readline() line = self.rfile.readline()
@ -337,11 +331,11 @@ class ServerConnection:
headers = read_headers(self.rfile) headers = read_headers(self.rfile)
if code >= 100 and code <= 199: if code >= 100 and code <= 199:
return self.read_response() return self.read_response()
if self.request.method == "HEAD" or code == 204 or code == 304: if request.method == "HEAD" or code == 204 or code == 304:
content = "" content = ""
else: else:
content = read_http_body(self.rfile, self, headers, True, self.config.body_size_limit) content = read_http_body(self.rfile, self, headers, True, self.config.body_size_limit)
return flow.Response(self.request, code, msg, headers, content, self.cert) return flow.Response(request, code, msg, headers, content, self.cert)
def terminate(self): def terminate(self):
try: try:
@ -393,10 +387,14 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
request = False request = False
response = response._send(self.mqueue) response = response._send(self.mqueue)
else: else:
server_conn = ServerConnection(self.config, request) if self.config.reverse_proxy:
server_conn.send() scheme, host, port = self.config.reverse_proxy
server_conn = ServerConnection(self.config, scheme, host, port)
else:
server_conn = ServerConnection(self.config, request.scheme, request.host, request.port)
server_conn.send(request)
try: try:
response = server_conn.read_response() response = server_conn.read_response(request)
except IOError, v: except IOError, v:
raise IOError, "Reading response: %s"%v raise IOError, "Reading response: %s"%v
response = response._send(self.mqueue) response = response._send(self.mqueue)