refactor protocol handling, fix #332

This commit is contained in:
Maximilian Hils 2014-08-30 20:15:19 +02:00
parent 82730c1c6f
commit 1f47f7b6b2
8 changed files with 86 additions and 75 deletions

View File

@ -6,21 +6,12 @@ protocols = {
'tcp': dict(handler=tcp.TCPHandler) 'tcp': dict(handler=tcp.TCPHandler)
} }
def protocol_handler(protocol):
"""
@type protocol: str
@returns: libmproxy.protocol.primitives.ProtocolHandler
"""
if protocol in protocols:
return protocols[protocol]["handler"]
def _handler(conntype, connection_handler): raise NotImplementedError("Unknown Protocol: %s" % protocol) # pragma: nocover
if conntype in protocols:
return protocols[conntype]["handler"](connection_handler)
raise NotImplementedError # pragma: nocover
def handle_messages(conntype, connection_handler):
return _handler(conntype, connection_handler).handle_messages()
def handle_error(conntype, connection_handler, error):
return _handler(conntype, connection_handler).handle_error(error)
def handle_server_reconnect(conntype, connection_handler, state):
return _handler(conntype, connection_handler).handle_server_reconnect(state)

View File

@ -5,7 +5,8 @@ import threading
from netlib import http, tcp, http_status from netlib import http, tcp, http_status
import netlib.utils import netlib.utils
from netlib.odict import ODict, ODictCaseless from netlib.odict import ODict, ODictCaseless
from .primitives import KILL, ProtocolHandler, LiveConnection, Flow, Error from .tcp import TCPHandler
from .primitives import KILL, ProtocolHandler, Flow, Error
from ..proxy.connection import ServerConnection from ..proxy.connection import ServerConnection
from .. import encoding, utils, controller, stateobject, proxy from .. import encoding, utils, controller, stateobject, proxy
@ -914,7 +915,6 @@ class HTTPHandler(ProtocolHandler):
def handle_messages(self): def handle_messages(self):
while self.handle_flow(): while self.handle_flow():
pass pass
self.c.close = True
def get_response_from_server(self, request, include_body=True): def get_response_from_server(self, request, include_body=True):
self.c.establish_server_connection() self.c.establish_server_connection()
@ -948,9 +948,9 @@ class HTTPHandler(ProtocolHandler):
req = HTTPRequest.from_stream(self.c.client_conn.rfile, req = HTTPRequest.from_stream(self.c.client_conn.rfile,
body_size_limit=self.c.config.body_size_limit) body_size_limit=self.c.config.body_size_limit)
self.c.log("request", "debug", [req._assemble_first_line(req.form_in)]) self.c.log("request", "debug", [req._assemble_first_line(req.form_in)])
send_request_upstream = self.process_request(flow, req) ret = self.process_request(flow, req)
if not send_request_upstream: if ret is not None:
return True return ret
# Be careful NOT to assign the request to the flow before # Be careful NOT to assign the request to the flow before
# process_request completes. This is because the call can raise an # process_request completes. This is because the call can raise an
@ -959,7 +959,7 @@ class HTTPHandler(ProtocolHandler):
# sent through to the Master. # sent through to the Master.
flow.request = req flow.request = req
request_reply = self.c.channel.ask("request", flow.request) request_reply = self.c.channel.ask("request", flow.request)
self.determine_server_address(flow, flow.request) self.determine_server_address(flow, flow.request) # The inline script may have changed request.host
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
if request_reply is None or request_reply == KILL: if request_reply is None or request_reply == KILL:
@ -1025,6 +1025,7 @@ class HTTPHandler(ProtocolHandler):
else: else:
return False return False
# We sent a CONNECT request to an upstream proxy.
if flow.request.form_in == "authority" and flow.response.code == 200: if flow.request.form_in == "authority" and flow.response.code == 200:
# TODO: Eventually add headers (space/usefulness tradeoff) # TODO: Eventually add headers (space/usefulness tradeoff)
# Make sure to add state info before the actual upgrade happens. # Make sure to add state info before the actual upgrade happens.
@ -1034,7 +1035,15 @@ class HTTPHandler(ProtocolHandler):
self.c.server_conn.state.append(("http", {"state": "connect", self.c.server_conn.state.append(("http", {"state": "connect",
"host": flow.request.host, "host": flow.request.host,
"port": flow.request.port})) "port": flow.request.port}))
if self.c.check_ignore_address((flow.request.host, flow.request.port)):
self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info")
TCPHandler(self.c).handle_messages()
return False
else:
if flow.request.port in self.c.config.ssl_ports:
self.ssl_upgrade() self.ssl_upgrade()
self.skip_authentication = True
# If the user has changed the target server on this connection, # If the user has changed the target server on this connection,
# restore the original target server # restore the original target server
@ -1065,7 +1074,7 @@ class HTTPHandler(ProtocolHandler):
if flow: if flow:
flow.error = Error(message) flow.error = Error(message)
# FIXME: no flows without request or with both request and response at the moement. # FIXME: no flows without request or with both request and response at the moment.
if flow.request and not flow.response: if flow.request and not flow.response:
self.c.channel.ask("error", flow.error) self.c.channel.ask("error", flow.error)
else: else:
@ -1103,6 +1112,13 @@ class HTTPHandler(ProtocolHandler):
self.c.log("Upgrade to SSL completed.", "debug") self.c.log("Upgrade to SSL completed.", "debug")
def process_request(self, flow, request): def process_request(self, flow, request):
"""
@returns:
True, if the request should not be sent upstream
False, if the connection should be aborted
None, if the request should be sent upstream
(a status code != None should be returned directly by handle_flow)
"""
if not self.skip_authentication: if not self.skip_authentication:
self.authenticate(request) self.authenticate(request)
@ -1115,8 +1131,8 @@ class HTTPHandler(ProtocolHandler):
if not self.c.config.get_upstream_server: if not self.c.config.get_upstream_server:
self.c.set_server_address((request.host, request.port), self.c.set_server_address((request.host, request.port),
proxy.AddressPriority.FROM_PROTOCOL) proxy.AddressPriority.FROM_PROTOCOL)
self.c.establish_server_connection()
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
self.c.establish_server_connection()
self.c.client_conn.send( self.c.client_conn.send(
'HTTP/1.1 200 Connection established\r\n' + 'HTTP/1.1 200 Connection established\r\n' +
'Content-Length: 0\r\n' + 'Content-Length: 0\r\n' +
@ -1124,18 +1140,24 @@ class HTTPHandler(ProtocolHandler):
'\r\n' '\r\n'
) )
self.ssl_upgrade() if self.c.check_ignore_address(self.c.server_conn.address):
self.skip_authentication = True self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info")
TCPHandler(self.c).handle_messages()
return False return False
else: else:
if self.c.server_conn.address.port in self.c.config.ssl_ports:
self.ssl_upgrade()
self.skip_authentication = True
return True return True
else:
return None
elif request.form_in == self.expected_form_in: elif request.form_in == self.expected_form_in:
if request.form_in == "absolute": if request.form_in == "absolute":
if request.scheme != "http": if request.scheme != "http":
raise http.HttpError(400, "Invalid request scheme: %s" % request.scheme) raise http.HttpError(400, "Invalid request scheme: %s" % request.scheme)
self.determine_server_address(flow, request) self.determine_server_address(flow, request)
request.form_out = self.expected_form_out request.form_out = self.expected_form_out
return True return None
raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" % raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" %
(self.expected_form_in, request.form_in)) (self.expected_form_in, request.form_in))
@ -1172,10 +1194,11 @@ class RequestReplayThread(threading.Thread):
server_address, server_ssl = False, False server_address, server_ssl = False, False
if self.config.get_upstream_server: if self.config.get_upstream_server:
try: try:
# this will fail in transparent mode
upstream_info = self.config.get_upstream_server(self.flow.client_conn) upstream_info = self.config.get_upstream_server(self.flow.client_conn)
server_ssl = upstream_info[1] server_ssl = upstream_info[1]
server_address = upstream_info[2:] server_address = upstream_info[2:]
except proxy.ProxyError: # this will fail in transparent mode except proxy.ProxyError:
pass pass
if not server_address: if not server_address:
server_address = (r.get_host(), r.get_port()) server_address = (r.get_host(), r.get_port())
@ -1184,7 +1207,7 @@ class RequestReplayThread(threading.Thread):
server.connect() server.connect()
if server_ssl or r.get_scheme() == "https": if server_ssl or r.get_scheme() == "https":
if self.config.http_form_out == "absolute": if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT
send_connect_request(server, r.get_host(), r.get_port()) send_connect_request(server, r.get_host(), r.get_port())
r.form_out = "relative" r.form_out = "relative"
server.establish_ssl(self.config.clientcerts, server.establish_ssl(self.config.clientcerts,

View File

@ -18,7 +18,7 @@ class TCPHandler(ProtocolHandler):
buf = memoryview(bytearray(self.chunk_size)) buf = memoryview(bytearray(self.chunk_size))
conns = [self.c.client_conn.rfile, self.c.server_conn.rfile] conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
while not self.c.close: while True:
r, _, _ = select.select(conns, [], [], 10) r, _, _ = select.select(conns, [], [], 10)
for rfile in r: for rfile in r:
if self.c.client_conn.rfile == rfile: if self.c.client_conn.rfile == rfile:
@ -51,7 +51,7 @@ class TCPHandler(ProtocolHandler):
dst.connection.shutdown(socket.SHUT_WR) dst.connection.shutdown(socket.SHUT_WR)
if len(conns) == 0: if len(conns) == 0:
self.c.close = True return
continue continue
if src.ssl_established or dst.ssl_established: if src.ssl_established or dst.ssl_established:

View File

@ -15,7 +15,7 @@ class ProxyConfig:
no_upstream_cert=False, body_size_limit=None, no_upstream_cert=False, body_size_limit=None,
mode=None, upstream_server=None, http_form_in=None, http_form_out=None, mode=None, upstream_server=None, http_form_in=None, http_form_out=None,
authenticator=None, ignore=[], authenticator=None, ignore=[],
ciphers=None, certs=[], certforward=False): ciphers=None, certs=[], certforward=False, ssl_ports=TRANSPARENT_SSL_PORTS):
self.ciphers = ciphers self.ciphers = ciphers
self.clientcerts = clientcerts self.clientcerts = clientcerts
self.no_upstream_cert = no_upstream_cert self.no_upstream_cert = no_upstream_cert
@ -49,6 +49,7 @@ class ProxyConfig:
for spec, cert in certs: for spec, cert in certs:
self.certstore.add_cert_file(spec, cert) self.certstore.add_cert_file(spec, cert)
self.certforward = certforward self.certforward = certforward
self.ssl_ports = ssl_ports
def process_proxy_options(parser, options): def process_proxy_options(parser, options):
@ -158,3 +159,9 @@ def ssl_option_group(parser):
action="store_true", dest="no_upstream_cert", action="store_true", dest="no_upstream_cert",
help="Don't connect to upstream server to look up certificate details." help="Don't connect to upstream server to look up certificate details."
) )
group.add_argument(
"--ssl-port", action="append", type=int, dest="ssl_ports", default=TRANSPARENT_SSL_PORTS,
metavar="PORT",
help="Can be passed multiple times. Specify destination ports which are assumed to be SSL. "
"Defaults to %s." % str(TRANSPARENT_SSL_PORTS)
)

View File

@ -6,13 +6,6 @@ class ProxyError(Exception):
super(ProxyError, self).__init__(self, message) super(ProxyError, self).__init__(self, message)
self.code, self.headers = code, headers self.code, self.headers = code, headers
class ConnectionTypeChange(Exception):
"""
Gets raised if the connection type has been changed (e.g. after HTTP/1.1 101 Switching Protocols).
It's up to the raising ProtocolHandler to specify the new conntype before raising the exception.
"""
pass
class ProxyServerError(Exception): class ProxyServerError(Exception):
pass pass

View File

@ -5,10 +5,9 @@ import socket
from OpenSSL import SSL from OpenSSL import SSL
from netlib import tcp from netlib import tcp
from .primitives import ProxyServerError, Log, ProxyError, ConnectionTypeChange, \ from .primitives import ProxyServerError, Log, ProxyError, AddressPriority
AddressPriority
from .connection import ClientConnection, ServerConnection from .connection import ClientConnection, ServerConnection
from ..protocol.handle import handle_messages, handle_error, handle_server_reconnect from ..protocol.handle import protocol_handler
from .. import version from .. import version
@ -66,7 +65,6 @@ class ConnectionHandler:
"""@type: libmproxy.proxy.connection.ServerConnection""" """@type: libmproxy.proxy.connection.ServerConnection"""
self.channel, self.server_version = channel, server_version self.channel, self.server_version = channel, server_version
self.close = False
self.conntype = "http" self.conntype = "http"
self.sni = None self.sni = None
@ -77,28 +75,28 @@ class ConnectionHandler:
# Can we already identify the target server and connect to it? # Can we already identify the target server and connect to it?
client_ssl, server_ssl = False, False client_ssl, server_ssl = False, False
if self.config.get_upstream_server: if self.config.get_upstream_server:
upstream_info = self.config.get_upstream_server( upstream_info = self.config.get_upstream_server(self.client_conn.connection)
self.client_conn.connection)
self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS) self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS)
client_ssl, server_ssl = upstream_info[:2] client_ssl, server_ssl = upstream_info[:2]
if self.check_ignore_address(self.server_conn.address):
self.log("Ignore host: %s:%s" % self.server_conn.address(), "info")
self.conntype = "tcp"
client_ssl, server_ssl = False, False
self.determine_conntype()
self.channel.ask("clientconnect", self) self.channel.ask("clientconnect", self)
# Check for existing connection: If an inline script already established a
# connection, do not apply client_ssl or server_ssl.
if self.server_conn and not self.server_conn.connection: if self.server_conn and not self.server_conn.connection:
self.establish_server_connection() self.establish_server_connection()
if client_ssl or server_ssl: if client_ssl or server_ssl:
self.establish_ssl(client=client_ssl, server=server_ssl) self.establish_ssl(client=client_ssl, server=server_ssl)
while not self.close: # Delegate handling to the protocol handler
try: protocol_handler(self.conntype)(self).handle_messages()
handle_messages(self.conntype, self)
except ConnectionTypeChange:
self.log("Connection type changed: %s" % self.conntype, "info")
continue
except (ProxyError, tcp.NetLibError), e: except (ProxyError, tcp.NetLibError), e:
handle_error(self.conntype, self, e) protocol_handler(self.conntype)(self).handle_error(e)
except Exception: except Exception:
import traceback, sys import traceback, sys
@ -106,7 +104,6 @@ class ConnectionHandler:
print >> sys.stderr, traceback.format_exc() print >> sys.stderr, traceback.format_exc()
print >> sys.stderr, "mitmproxy has crashed!" print >> sys.stderr, "mitmproxy has crashed!"
print >> sys.stderr, "Please lodge a bug report at: https://github.com/mitmproxy/mitmproxy" print >> sys.stderr, "Please lodge a bug report at: https://github.com/mitmproxy/mitmproxy"
raise
self.del_server_connection() self.del_server_connection()
self.log("clientdisconnect", "info") self.log("clientdisconnect", "info")
@ -124,12 +121,13 @@ class ConnectionHandler:
self.server_conn = None self.server_conn = None
self.sni = None self.sni = None
def determine_conntype(self): def check_ignore_address(self, address):
if self.server_conn and any(rex.search(self.server_conn.address.host) for rex in self.config.ignore): address = tcp.Address.wrap(address)
self.log("Ignore host: %s" % self.server_conn.address.host, "info") host = "%s:%s" % (address.host, address.port)
self.conntype = "tcp" if host and any(rex.search(host) for rex in self.config.ignore):
return True
else: else:
self.conntype = "http" return False
def set_server_address(self, address, priority): def set_server_address(self, address, priority):
""" """
@ -175,15 +173,8 @@ class ConnectionHandler:
def establish_ssl(self, client=False, server=False): def establish_ssl(self, client=False, server=False):
""" """
Establishes SSL on the existing connection(s) to the server or the client, Establishes SSL on the existing connection(s) to the server or the client,
as specified by the parameters. If the target server is on the pass-through list, as specified by the parameters.
the conntype attribute will be changed and a ConnTypeChanged exception will be raised.
""" """
# If the host is on our ignore list, change to passthrough/ignore mode.
for host in (self.server_conn.address.host, self.sni):
if host and any(rex.search(host) for rex in self.config.ignore):
self.log("Ignore host: %s" % host, "info")
self.conntype = "tcp"
raise ConnectionTypeChange()
# Logging # Logging
if client or server: if client or server:
@ -224,7 +215,7 @@ class ConnectionHandler:
self.establish_server_connection() self.establish_server_connection()
for s in state: for s in state:
handle_server_reconnect(s[0], self, s[1]) protocol_handler(s[0])(self).handle_server_reconnect(s[1])
self.server_conn.state = state self.server_conn.state = state
if had_ssl: if had_ssl:
@ -288,4 +279,5 @@ class ConnectionHandler:
# make dang sure it doesn't happen. # make dang sure it doesn't happen.
except Exception: # pragma: no cover except Exception: # pragma: no cover
import traceback import traceback
self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error") self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error")

View File

@ -175,6 +175,7 @@ class TestHTTPAuth(tservers.HTTPProxTest):
class TestHTTPConnectSSLError(tservers.HTTPProxTest): class TestHTTPConnectSSLError(tservers.HTTPProxTest):
certfile = True certfile = True
def test_go(self): def test_go(self):
self.config.ssl_ports.append(self.proxy.port)
p = self.pathoc_raw() p = self.pathoc_raw()
dst = ("localhost", self.proxy.port) dst = ("localhost", self.proxy.port)
p.connect(connect_to=dst) p.connect(connect_to=dst)

View File

@ -95,6 +95,7 @@ class ProxTestBase(object):
confdir = cls.confdir, confdir = cls.confdir,
authenticator = cls.authenticator, authenticator = cls.authenticator,
certforward = cls.certforward, certforward = cls.certforward,
ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []),
**pconf **pconf
) )
tmaster = cls.masterclass(cls.config) tmaster = cls.masterclass(cls.config)
@ -267,17 +268,20 @@ class ChainProxTest(ProxTestBase):
Chain n instances of mitmproxy in a row - because we can. Chain n instances of mitmproxy in a row - because we can.
""" """
n = 2 n = 2
chain_config = [lambda port: ProxyConfig( chain_config = [lambda port, sslports: ProxyConfig(
upstream_server= (False, False, "127.0.0.1", port), upstream_server= (False, False, "127.0.0.1", port),
http_form_in = "absolute", http_form_in = "absolute",
http_form_out = "absolute" http_form_out = "absolute",
ssl_ports=sslports
)] * n )] * n
@classmethod @classmethod
def setupAll(cls): def setupAll(cls):
super(ChainProxTest, cls).setupAll() super(ChainProxTest, cls).setupAll()
cls.chain = [] cls.chain = []
for i in range(cls.n): for i in range(cls.n):
config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port) sslports = [cls.server.port, cls.server2.port]
config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port,
sslports)
tmaster = cls.masterclass(config) tmaster = cls.masterclass(config)
tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp) tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp)
cls.chain.append(ProxyThread(tmaster)) cls.chain.append(ProxyThread(tmaster))