get server reconnect right, fix timestamps

This commit is contained in:
Maximilian Hils 2014-01-18 17:15:33 +01:00
parent 862b532fff
commit 6c24b1d0d2
2 changed files with 71 additions and 65 deletions

View File

@ -3,12 +3,10 @@ from netlib import http, http_status, tcp
import netlib.utils import netlib.utils
from netlib.odict import ODictCaseless from netlib.odict import ODictCaseless
import select import select
from proxy import ProxyError from proxy import ProxyError, KILL
KILL = 0 # FIXME: Remove duplication with proxy module
LEGACY = True LEGACY = True
def _handle(msg, conntype, connection_handler, *args, **kwargs): def _handle(msg, conntype, connection_handler, *args, **kwargs):
handler = None handler = None
if conntype == "http": if conntype == "http":
@ -106,11 +104,11 @@ class HTTPResponse(HTTPMessage):
if not include_content: if not include_content:
raise NotImplementedError raise NotImplementedError
timestamp_start = libmproxy.utils.timestamp()
httpversion, code, msg, headers, content = http.read_response( httpversion, code, msg, headers, content = http.read_response(
rfile, rfile,
request_method, request_method,
body_size_limit) body_size_limit)
timestamp_start = rfile.first_byte_timestamp
timestamp_end = libmproxy.utils.timestamp() timestamp_end = libmproxy.utils.timestamp()
return HTTPResponse(httpversion, code, msg, headers, content, timestamp_start, timestamp_end) return HTTPResponse(httpversion, code, msg, headers, content, timestamp_start, timestamp_end)
@ -165,8 +163,8 @@ class HTTPRequest(HTTPMessage):
httpversion, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \ httpversion, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \
= None, None, None, None, None, None, None, None, None, None = None, None, None, None, None, None, None, None, None, None
timestamp_start = libmproxy.utils.timestamp()
request_line = HTTPHandler.get_line(rfile) request_line = HTTPHandler.get_line(rfile)
timestamp_start = rfile.first_byte_timestamp
request_line_parts = http.parse_init(request_line) request_line_parts = http.parse_init(request_line)
if not request_line_parts: if not request_line_parts:
@ -210,33 +208,34 @@ class HTTPHandler(ProtocolHandler):
pass pass
self.c.close = True self.c.close = True
""" def get_response_from_server(self, request):
def wait_for_message(self): request_raw = request._assemble()
"""
Check both the client connection and the server connection (if present) for readable data. for i in range(2):
""" try:
conns = [self.c.client_conn.rfile] self.c.server_conn.wfile.write(request_raw)
if self.c.server_conn: self.c.server_conn.wfile.flush()
conns.append(self.c.server_conn.rfile) return HTTPResponse.from_stream(self.c.server_conn.rfile, request.method,
while True: body_size_limit=self.c.config.body_size_limit)
readable, _, _ = select.select(conns, [], [], 10) except (tcp.NetLibDisconnect, http.HttpErrorConnClosed), v:
if self.c.client_conn.rfile in readable: self.c.log("error in server communication: %s" % str(v))
return if i < 1:
if self.c.server_conn.rfile in readable: # In any case, we try to reconnect at least once.
data = self.c.server_conn.rfile.read(1) # This is necessary because it might be possible that we already initiated an upstream connection
if data == "": # after clientconnect that has already been expired, e.g consider the following event log:
raise tcp.NetLibDisconnect # > clientconnect (transparent mode destination known)
elif data == "\r" or data == "\n": # > serverconnect
self.c.log("Received an empty line from server") # > read n% of large request
pass # Possible leftover from previous message # > server detects timeout, disconnects
# > read (100-n)% of large request
# > send large request upstream
self.c.server_reconnect()
else: else:
raise ProxyError(502, "Unexpected message from server") raise v
"""
def handle_flow(self): def handle_flow(self):
flow = HTTPFlow(self.c.client_conn, self.c.server_conn, None, None, None) flow = HTTPFlow(self.c.client_conn, self.c.server_conn, None, None, None)
try: try:
# self.wait_for_message()
flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile, flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile,
body_size_limit=self.c.config.body_size_limit) body_size_limit=self.c.config.body_size_limit)
self.c.log("request", [flow.request._assemble_request_line(flow.request.form_in)]) self.c.log("request", [flow.request._assemble_request_line(flow.request.form_in)])
@ -250,11 +249,7 @@ class HTTPHandler(ProtocolHandler):
if isinstance(request_reply, HTTPResponse): if isinstance(request_reply, HTTPResponse):
flow.response = request_reply flow.response = request_reply
else: else:
raw = flow.request._assemble() flow.response = self.get_response_from_server(flow.request)
self.c.server_conn.wfile.write(raw)
self.c.server_conn.wfile.flush()
flow.response = HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method,
body_size_limit=self.c.config.body_size_limit)
self.c.log("response", [flow.response._assemble_response_line()]) self.c.log("response", [flow.response._assemble_response_line()])
response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse",

View File

@ -25,7 +25,8 @@ class Log:
class ProxyConfig: class ProxyConfig:
def __init__(self, certfile = None, cacert = None, clientcerts = None, no_upstream_cert=False, body_size_limit = None, reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): def __init__(self, certfile=None, cacert=None, clientcerts=None, no_upstream_cert=False, body_size_limit=None,
reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None):
self.certfile = certfile self.certfile = certfile
self.cacert = cacert self.cacert = cacert
self.clientcerts = clientcerts self.clientcerts = clientcerts
@ -91,7 +92,6 @@ class ServerConnection(tcp.TCPClient):
raise ProxyError(400, str(v)) raise ProxyError(400, str(v))
def finish(self): def finish(self):
if self.connection: # Eventually, we had an error during .connect() and aren't even connected.
tcp.TCPClient.finish(self) tcp.TCPClient.finish(self)
self.timestamp_end = utils.timestamp() self.timestamp_end = utils.timestamp()
@ -141,7 +141,7 @@ class ConnectionHandler:
self.mode = "transparent" self.mode = "transparent"
def del_server_connection(self): def del_server_connection(self):
if self.server_conn: if self.server_conn and self.server_conn.connection:
self.server_conn.finish() self.server_conn.finish()
self.log("serverdisconnect", ["%s:%s" % (self.server_conn.host, self.server_conn.port)]) self.log("serverdisconnect", ["%s:%s" % (self.server_conn.host, self.server_conn.port)])
self.channel.tell("serverdisconnect", self) self.channel.tell("serverdisconnect", self)
@ -161,7 +161,8 @@ class ConnectionHandler:
if self.config.reverse_proxy: if self.config.reverse_proxy:
server_address = self.config.reverse_proxy[1:] server_address = self.config.reverse_proxy[1:]
elif self.config.transparent_proxy: elif self.config.transparent_proxy:
server_address = self.config.transparent_proxy["resolver"].original_addr(self.client_conn.connection) server_address = self.config.transparent_proxy["resolver"].original_addr(
self.client_conn.connection)
if not server_address: if not server_address:
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
self.log("transparent to %s:%s" % server_address) self.log("transparent to %s:%s" % server_address)
@ -219,7 +220,7 @@ class ConnectionHandler:
self.log("serverconnect", ["%s:%s" % (host, port)]) self.log("serverconnect", ["%s:%s" % (host, port)])
self.channel.tell("serverconnect", self) self.channel.tell("serverconnect", self)
def establish_ssl(self, client, server): def establish_ssl(self, client=False, server=False):
""" """
Establishes SSL on the existing connection(s) to the server or the client, Establishes SSL on the existing connection(s) to the server or the client,
as specified by the parameters. If the target server is on the pass-through list, as specified by the parameters. If the target server is on the pass-through list,
@ -241,6 +242,14 @@ class ConnectionHandler:
self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert,
handle_sni=self.handle_sni) handle_sni=self.handle_sni)
def server_reconnect(self):
self.log("server reconnect")
had_ssl, sni = self.server_conn.ssl_established, self.sni
self.establish_server_connection(*self.server_conn.address)
if had_ssl:
self.sni = sni
self.establish_ssl(server=True)
def log(self, msg, subs=()): def log(self, msg, subs=()):
msg = [ msg = [
"%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg) "%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg)
@ -291,12 +300,14 @@ class ConnectionHandler:
except Exception, e: # pragma: no cover except Exception, e: # pragma: no cover
pass pass
class ProxyServerError(Exception): pass class ProxyServerError(Exception): pass
class ProxyServer(tcp.TCPServer): class ProxyServer(tcp.TCPServer):
allow_reuse_address = True allow_reuse_address = True
bound = True bound = True
def __init__(self, config, port, address='', server_version=version.NAMEVERSION): def __init__(self, config, port, address='', server_version=version.NAMEVERSION):
""" """
Raises ProxyServerError if there's a startup problem. Raises ProxyServerError if there's a startup problem.
@ -324,6 +335,7 @@ class ProxyServer(tcp.TCPServer):
class DummyServer: class DummyServer:
bound = False bound = False
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
@ -349,7 +361,6 @@ def certificate_option_group(parser):
) )
def process_proxy_options(parser, options): def process_proxy_options(parser, options):
if options.cert: if options.cert:
options.cert = os.path.expanduser(options.cert) options.cert = os.path.expanduser(options.cert)