From 1f47f7b6b29cd1229264edf75194652824d94705 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 30 Aug 2014 20:15:19 +0200 Subject: [PATCH] refactor protocol handling, fix #332 --- libmproxy/protocol/handle.py | 25 +++++----------- libmproxy/protocol/http.py | 55 +++++++++++++++++++++++++---------- libmproxy/protocol/tcp.py | 4 +-- libmproxy/proxy/config.py | 9 +++++- libmproxy/proxy/primitives.py | 7 ----- libmproxy/proxy/server.py | 50 +++++++++++++------------------ test/test_server.py | 1 + test/tservers.py | 10 +++++-- 8 files changed, 86 insertions(+), 75 deletions(-) diff --git a/libmproxy/protocol/handle.py b/libmproxy/protocol/handle.py index a238b3495..100c73686 100644 --- a/libmproxy/protocol/handle.py +++ b/libmproxy/protocol/handle.py @@ -6,21 +6,12 @@ protocols = { 'tcp': dict(handler=tcp.TCPHandler) } +def protocol_handler(protocol): + """ + @type protocol: str + @returns: libmproxy.protocol.primitives.ProtocolHandler + """ + if protocol in protocols: + return protocols[protocol]["handler"] -def _handler(conntype, connection_handler): - if conntype in protocols: - return protocols[conntype]["handler"](connection_handler) - - raise NotImplementedError # pragma: nocover - - -def handle_messages(conntype, connection_handler): - return _handler(conntype, connection_handler).handle_messages() - - -def handle_error(conntype, connection_handler, error): - return _handler(conntype, connection_handler).handle_error(error) - - -def handle_server_reconnect(conntype, connection_handler, state): - return _handler(conntype, connection_handler).handle_server_reconnect(state) \ No newline at end of file + raise NotImplementedError("Unknown Protocol: %s" % protocol) # pragma: nocover \ No newline at end of file diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 38a6cb498..b635d0719 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -5,7 +5,8 @@ import threading from netlib import http, tcp, http_status import netlib.utils from netlib.odict import ODict, ODictCaseless -from .primitives import KILL, ProtocolHandler, LiveConnection, Flow, Error +from .tcp import TCPHandler +from .primitives import KILL, ProtocolHandler, Flow, Error from ..proxy.connection import ServerConnection from .. import encoding, utils, controller, stateobject, proxy @@ -914,7 +915,6 @@ class HTTPHandler(ProtocolHandler): def handle_messages(self): while self.handle_flow(): pass - self.c.close = True def get_response_from_server(self, request, include_body=True): self.c.establish_server_connection() @@ -948,9 +948,9 @@ class HTTPHandler(ProtocolHandler): req = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit) self.c.log("request", "debug", [req._assemble_first_line(req.form_in)]) - send_request_upstream = self.process_request(flow, req) - if not send_request_upstream: - return True + ret = self.process_request(flow, req) + if ret is not None: + return ret # Be careful NOT to assign the request to the flow before # process_request completes. This is because the call can raise an @@ -959,7 +959,7 @@ class HTTPHandler(ProtocolHandler): # sent through to the Master. flow.request = req request_reply = self.c.channel.ask("request", flow.request) - self.determine_server_address(flow, flow.request) + self.determine_server_address(flow, flow.request) # The inline script may have changed request.host flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow if request_reply is None or request_reply == KILL: @@ -1025,6 +1025,7 @@ class HTTPHandler(ProtocolHandler): else: return False + # We sent a CONNECT request to an upstream proxy. if flow.request.form_in == "authority" and flow.response.code == 200: # TODO: Eventually add headers (space/usefulness tradeoff) # Make sure to add state info before the actual upgrade happens. @@ -1034,7 +1035,15 @@ class HTTPHandler(ProtocolHandler): self.c.server_conn.state.append(("http", {"state": "connect", "host": flow.request.host, "port": flow.request.port})) - self.ssl_upgrade() + + if self.c.check_ignore_address((flow.request.host, flow.request.port)): + self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info") + TCPHandler(self.c).handle_messages() + return False + else: + if flow.request.port in self.c.config.ssl_ports: + self.ssl_upgrade() + self.skip_authentication = True # If the user has changed the target server on this connection, # restore the original target server @@ -1065,7 +1074,7 @@ class HTTPHandler(ProtocolHandler): if flow: flow.error = Error(message) - # FIXME: no flows without request or with both request and response at the moement. + # FIXME: no flows without request or with both request and response at the moment. if flow.request and not flow.response: self.c.channel.ask("error", flow.error) else: @@ -1103,6 +1112,13 @@ class HTTPHandler(ProtocolHandler): self.c.log("Upgrade to SSL completed.", "debug") def process_request(self, flow, request): + """ + @returns: + True, if the request should not be sent upstream + False, if the connection should be aborted + None, if the request should be sent upstream + (a status code != None should be returned directly by handle_flow) + """ if not self.skip_authentication: self.authenticate(request) @@ -1115,8 +1131,8 @@ class HTTPHandler(ProtocolHandler): if not self.c.config.get_upstream_server: self.c.set_server_address((request.host, request.port), proxy.AddressPriority.FROM_PROTOCOL) - self.c.establish_server_connection() flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow + self.c.establish_server_connection() self.c.client_conn.send( 'HTTP/1.1 200 Connection established\r\n' + 'Content-Length: 0\r\n' + @@ -1124,18 +1140,24 @@ class HTTPHandler(ProtocolHandler): '\r\n' ) - self.ssl_upgrade() - self.skip_authentication = True - return False + if self.c.check_ignore_address(self.c.server_conn.address): + self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info") + TCPHandler(self.c).handle_messages() + return False + else: + if self.c.server_conn.address.port in self.c.config.ssl_ports: + self.ssl_upgrade() + self.skip_authentication = True + return True else: - return True + return None elif request.form_in == self.expected_form_in: if request.form_in == "absolute": if request.scheme != "http": raise http.HttpError(400, "Invalid request scheme: %s" % request.scheme) self.determine_server_address(flow, request) request.form_out = self.expected_form_out - return True + return None raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" % (self.expected_form_in, request.form_in)) @@ -1172,10 +1194,11 @@ class RequestReplayThread(threading.Thread): server_address, server_ssl = False, False if self.config.get_upstream_server: try: + # this will fail in transparent mode upstream_info = self.config.get_upstream_server(self.flow.client_conn) server_ssl = upstream_info[1] server_address = upstream_info[2:] - except proxy.ProxyError: # this will fail in transparent mode + except proxy.ProxyError: pass if not server_address: server_address = (r.get_host(), r.get_port()) @@ -1184,7 +1207,7 @@ class RequestReplayThread(threading.Thread): server.connect() if server_ssl or r.get_scheme() == "https": - if self.config.http_form_out == "absolute": + if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT send_connect_request(server, r.get_host(), r.get_port()) r.form_out = "relative" server.establish_ssl(self.config.clientcerts, diff --git a/libmproxy/protocol/tcp.py b/libmproxy/protocol/tcp.py index a77a90965..00dbf4b34 100644 --- a/libmproxy/protocol/tcp.py +++ b/libmproxy/protocol/tcp.py @@ -18,7 +18,7 @@ class TCPHandler(ProtocolHandler): buf = memoryview(bytearray(self.chunk_size)) conns = [self.c.client_conn.rfile, self.c.server_conn.rfile] - while not self.c.close: + while True: r, _, _ = select.select(conns, [], [], 10) for rfile in r: if self.c.client_conn.rfile == rfile: @@ -51,7 +51,7 @@ class TCPHandler(ProtocolHandler): dst.connection.shutdown(socket.SHUT_WR) if len(conns) == 0: - self.c.close = True + return continue if src.ssl_established or dst.ssl_established: diff --git a/libmproxy/proxy/config.py b/libmproxy/proxy/config.py index afa7440c5..6d4c078b2 100644 --- a/libmproxy/proxy/config.py +++ b/libmproxy/proxy/config.py @@ -15,7 +15,7 @@ class ProxyConfig: no_upstream_cert=False, body_size_limit=None, mode=None, upstream_server=None, http_form_in=None, http_form_out=None, authenticator=None, ignore=[], - ciphers=None, certs=[], certforward=False): + ciphers=None, certs=[], certforward=False, ssl_ports=TRANSPARENT_SSL_PORTS): self.ciphers = ciphers self.clientcerts = clientcerts self.no_upstream_cert = no_upstream_cert @@ -49,6 +49,7 @@ class ProxyConfig: for spec, cert in certs: self.certstore.add_cert_file(spec, cert) self.certforward = certforward + self.ssl_ports = ssl_ports def process_proxy_options(parser, options): @@ -157,4 +158,10 @@ def ssl_option_group(parser): "--no-upstream-cert", default=False, action="store_true", dest="no_upstream_cert", help="Don't connect to upstream server to look up certificate details." + ) + group.add_argument( + "--ssl-port", action="append", type=int, dest="ssl_ports", default=TRANSPARENT_SSL_PORTS, + metavar="PORT", + help="Can be passed multiple times. Specify destination ports which are assumed to be SSL. " + "Defaults to %s." % str(TRANSPARENT_SSL_PORTS) ) \ No newline at end of file diff --git a/libmproxy/proxy/primitives.py b/libmproxy/proxy/primitives.py index dc4b7e220..f2c0a9842 100644 --- a/libmproxy/proxy/primitives.py +++ b/libmproxy/proxy/primitives.py @@ -6,13 +6,6 @@ class ProxyError(Exception): super(ProxyError, self).__init__(self, message) self.code, self.headers = code, headers -class ConnectionTypeChange(Exception): - """ - Gets raised if the connection type has been changed (e.g. after HTTP/1.1 101 Switching Protocols). - It's up to the raising ProtocolHandler to specify the new conntype before raising the exception. - """ - pass - class ProxyServerError(Exception): pass diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index 02b86d71d..56e8860b4 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -5,10 +5,9 @@ import socket from OpenSSL import SSL from netlib import tcp -from .primitives import ProxyServerError, Log, ProxyError, ConnectionTypeChange, \ - AddressPriority +from .primitives import ProxyServerError, Log, ProxyError, AddressPriority from .connection import ClientConnection, ServerConnection -from ..protocol.handle import handle_messages, handle_error, handle_server_reconnect +from ..protocol.handle import protocol_handler from .. import version @@ -66,7 +65,6 @@ class ConnectionHandler: """@type: libmproxy.proxy.connection.ServerConnection""" self.channel, self.server_version = channel, server_version - self.close = False self.conntype = "http" self.sni = None @@ -77,28 +75,28 @@ class ConnectionHandler: # Can we already identify the target server and connect to it? client_ssl, server_ssl = False, False if self.config.get_upstream_server: - upstream_info = self.config.get_upstream_server( - self.client_conn.connection) + upstream_info = self.config.get_upstream_server(self.client_conn.connection) self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS) client_ssl, server_ssl = upstream_info[:2] + if self.check_ignore_address(self.server_conn.address): + self.log("Ignore host: %s:%s" % self.server_conn.address(), "info") + self.conntype = "tcp" + client_ssl, server_ssl = False, False - self.determine_conntype() self.channel.ask("clientconnect", self) + # Check for existing connection: If an inline script already established a + # connection, do not apply client_ssl or server_ssl. if self.server_conn and not self.server_conn.connection: self.establish_server_connection() if client_ssl or server_ssl: self.establish_ssl(client=client_ssl, server=server_ssl) - while not self.close: - try: - handle_messages(self.conntype, self) - except ConnectionTypeChange: - self.log("Connection type changed: %s" % self.conntype, "info") - continue + # Delegate handling to the protocol handler + protocol_handler(self.conntype)(self).handle_messages() except (ProxyError, tcp.NetLibError), e: - handle_error(self.conntype, self, e) + protocol_handler(self.conntype)(self).handle_error(e) except Exception: import traceback, sys @@ -106,7 +104,6 @@ class ConnectionHandler: print >> sys.stderr, traceback.format_exc() print >> sys.stderr, "mitmproxy has crashed!" print >> sys.stderr, "Please lodge a bug report at: https://github.com/mitmproxy/mitmproxy" - raise self.del_server_connection() self.log("clientdisconnect", "info") @@ -124,12 +121,13 @@ class ConnectionHandler: self.server_conn = None self.sni = None - def determine_conntype(self): - if self.server_conn and any(rex.search(self.server_conn.address.host) for rex in self.config.ignore): - self.log("Ignore host: %s" % self.server_conn.address.host, "info") - self.conntype = "tcp" + def check_ignore_address(self, address): + address = tcp.Address.wrap(address) + host = "%s:%s" % (address.host, address.port) + if host and any(rex.search(host) for rex in self.config.ignore): + return True else: - self.conntype = "http" + return False def set_server_address(self, address, priority): """ @@ -175,15 +173,8 @@ class ConnectionHandler: def establish_ssl(self, client=False, server=False): """ Establishes SSL on the existing connection(s) to the server or the client, - as specified by the parameters. If the target server is on the pass-through list, - the conntype attribute will be changed and a ConnTypeChanged exception will be raised. + as specified by the parameters. """ - # If the host is on our ignore list, change to passthrough/ignore mode. - for host in (self.server_conn.address.host, self.sni): - if host and any(rex.search(host) for rex in self.config.ignore): - self.log("Ignore host: %s" % host, "info") - self.conntype = "tcp" - raise ConnectionTypeChange() # Logging if client or server: @@ -224,7 +215,7 @@ class ConnectionHandler: self.establish_server_connection() for s in state: - handle_server_reconnect(s[0], self, s[1]) + protocol_handler(s[0])(self).handle_server_reconnect(s[1]) self.server_conn.state = state if had_ssl: @@ -288,4 +279,5 @@ class ConnectionHandler: # make dang sure it doesn't happen. except Exception: # pragma: no cover import traceback + self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error") \ No newline at end of file diff --git a/test/test_server.py b/test/test_server.py index 017faacb0..52efa5f28 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -175,6 +175,7 @@ class TestHTTPAuth(tservers.HTTPProxTest): class TestHTTPConnectSSLError(tservers.HTTPProxTest): certfile = True def test_go(self): + self.config.ssl_ports.append(self.proxy.port) p = self.pathoc_raw() dst = ("localhost", self.proxy.port) p.connect(connect_to=dst) diff --git a/test/tservers.py b/test/tservers.py index 597ad4ee2..a12a440e3 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -95,6 +95,7 @@ class ProxTestBase(object): confdir = cls.confdir, authenticator = cls.authenticator, certforward = cls.certforward, + ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []), **pconf ) tmaster = cls.masterclass(cls.config) @@ -267,17 +268,20 @@ class ChainProxTest(ProxTestBase): Chain n instances of mitmproxy in a row - because we can. """ n = 2 - chain_config = [lambda port: ProxyConfig( + chain_config = [lambda port, sslports: ProxyConfig( upstream_server= (False, False, "127.0.0.1", port), http_form_in = "absolute", - http_form_out = "absolute" + http_form_out = "absolute", + ssl_ports=sslports )] * n @classmethod def setupAll(cls): super(ChainProxTest, cls).setupAll() cls.chain = [] for i in range(cls.n): - config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port) + sslports = [cls.server.port, cls.server2.port] + config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port, + sslports) tmaster = cls.masterclass(config) tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp) cls.chain.append(ProxyThread(tmaster))