refactor protocol handling, fix #332

This commit is contained in:
Maximilian Hils 2014-08-30 20:15:19 +02:00
parent 82730c1c6f
commit 1f47f7b6b2
8 changed files with 86 additions and 75 deletions

View File

@ -6,21 +6,12 @@ protocols = {
'tcp': dict(handler=tcp.TCPHandler)
}
def protocol_handler(protocol):
"""
@type protocol: str
@returns: libmproxy.protocol.primitives.ProtocolHandler
"""
if protocol in protocols:
return protocols[protocol]["handler"]
def _handler(conntype, connection_handler):
if conntype in protocols:
return protocols[conntype]["handler"](connection_handler)
raise NotImplementedError # pragma: nocover
def handle_messages(conntype, connection_handler):
return _handler(conntype, connection_handler).handle_messages()
def handle_error(conntype, connection_handler, error):
return _handler(conntype, connection_handler).handle_error(error)
def handle_server_reconnect(conntype, connection_handler, state):
return _handler(conntype, connection_handler).handle_server_reconnect(state)
raise NotImplementedError("Unknown Protocol: %s" % protocol) # pragma: nocover

View File

@ -5,7 +5,8 @@ import threading
from netlib import http, tcp, http_status
import netlib.utils
from netlib.odict import ODict, ODictCaseless
from .primitives import KILL, ProtocolHandler, LiveConnection, Flow, Error
from .tcp import TCPHandler
from .primitives import KILL, ProtocolHandler, Flow, Error
from ..proxy.connection import ServerConnection
from .. import encoding, utils, controller, stateobject, proxy
@ -914,7 +915,6 @@ class HTTPHandler(ProtocolHandler):
def handle_messages(self):
while self.handle_flow():
pass
self.c.close = True
def get_response_from_server(self, request, include_body=True):
self.c.establish_server_connection()
@ -948,9 +948,9 @@ class HTTPHandler(ProtocolHandler):
req = HTTPRequest.from_stream(self.c.client_conn.rfile,
body_size_limit=self.c.config.body_size_limit)
self.c.log("request", "debug", [req._assemble_first_line(req.form_in)])
send_request_upstream = self.process_request(flow, req)
if not send_request_upstream:
return True
ret = self.process_request(flow, req)
if ret is not None:
return ret
# Be careful NOT to assign the request to the flow before
# process_request completes. This is because the call can raise an
@ -959,7 +959,7 @@ class HTTPHandler(ProtocolHandler):
# sent through to the Master.
flow.request = req
request_reply = self.c.channel.ask("request", flow.request)
self.determine_server_address(flow, flow.request)
self.determine_server_address(flow, flow.request) # The inline script may have changed request.host
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
if request_reply is None or request_reply == KILL:
@ -1025,6 +1025,7 @@ class HTTPHandler(ProtocolHandler):
else:
return False
# We sent a CONNECT request to an upstream proxy.
if flow.request.form_in == "authority" and flow.response.code == 200:
# TODO: Eventually add headers (space/usefulness tradeoff)
# Make sure to add state info before the actual upgrade happens.
@ -1034,7 +1035,15 @@ class HTTPHandler(ProtocolHandler):
self.c.server_conn.state.append(("http", {"state": "connect",
"host": flow.request.host,
"port": flow.request.port}))
if self.c.check_ignore_address((flow.request.host, flow.request.port)):
self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info")
TCPHandler(self.c).handle_messages()
return False
else:
if flow.request.port in self.c.config.ssl_ports:
self.ssl_upgrade()
self.skip_authentication = True
# If the user has changed the target server on this connection,
# restore the original target server
@ -1065,7 +1074,7 @@ class HTTPHandler(ProtocolHandler):
if flow:
flow.error = Error(message)
# FIXME: no flows without request or with both request and response at the moement.
# FIXME: no flows without request or with both request and response at the moment.
if flow.request and not flow.response:
self.c.channel.ask("error", flow.error)
else:
@ -1103,6 +1112,13 @@ class HTTPHandler(ProtocolHandler):
self.c.log("Upgrade to SSL completed.", "debug")
def process_request(self, flow, request):
"""
@returns:
True, if the request should not be sent upstream
False, if the connection should be aborted
None, if the request should be sent upstream
(a status code != None should be returned directly by handle_flow)
"""
if not self.skip_authentication:
self.authenticate(request)
@ -1115,8 +1131,8 @@ class HTTPHandler(ProtocolHandler):
if not self.c.config.get_upstream_server:
self.c.set_server_address((request.host, request.port),
proxy.AddressPriority.FROM_PROTOCOL)
self.c.establish_server_connection()
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
self.c.establish_server_connection()
self.c.client_conn.send(
'HTTP/1.1 200 Connection established\r\n' +
'Content-Length: 0\r\n' +
@ -1124,18 +1140,24 @@ class HTTPHandler(ProtocolHandler):
'\r\n'
)
self.ssl_upgrade()
self.skip_authentication = True
if self.c.check_ignore_address(self.c.server_conn.address):
self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info")
TCPHandler(self.c).handle_messages()
return False
else:
if self.c.server_conn.address.port in self.c.config.ssl_ports:
self.ssl_upgrade()
self.skip_authentication = True
return True
else:
return None
elif request.form_in == self.expected_form_in:
if request.form_in == "absolute":
if request.scheme != "http":
raise http.HttpError(400, "Invalid request scheme: %s" % request.scheme)
self.determine_server_address(flow, request)
request.form_out = self.expected_form_out
return True
return None
raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" %
(self.expected_form_in, request.form_in))
@ -1172,10 +1194,11 @@ class RequestReplayThread(threading.Thread):
server_address, server_ssl = False, False
if self.config.get_upstream_server:
try:
# this will fail in transparent mode
upstream_info = self.config.get_upstream_server(self.flow.client_conn)
server_ssl = upstream_info[1]
server_address = upstream_info[2:]
except proxy.ProxyError: # this will fail in transparent mode
except proxy.ProxyError:
pass
if not server_address:
server_address = (r.get_host(), r.get_port())
@ -1184,7 +1207,7 @@ class RequestReplayThread(threading.Thread):
server.connect()
if server_ssl or r.get_scheme() == "https":
if self.config.http_form_out == "absolute":
if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT
send_connect_request(server, r.get_host(), r.get_port())
r.form_out = "relative"
server.establish_ssl(self.config.clientcerts,

View File

@ -18,7 +18,7 @@ class TCPHandler(ProtocolHandler):
buf = memoryview(bytearray(self.chunk_size))
conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
while not self.c.close:
while True:
r, _, _ = select.select(conns, [], [], 10)
for rfile in r:
if self.c.client_conn.rfile == rfile:
@ -51,7 +51,7 @@ class TCPHandler(ProtocolHandler):
dst.connection.shutdown(socket.SHUT_WR)
if len(conns) == 0:
self.c.close = True
return
continue
if src.ssl_established or dst.ssl_established:

View File

@ -15,7 +15,7 @@ class ProxyConfig:
no_upstream_cert=False, body_size_limit=None,
mode=None, upstream_server=None, http_form_in=None, http_form_out=None,
authenticator=None, ignore=[],
ciphers=None, certs=[], certforward=False):
ciphers=None, certs=[], certforward=False, ssl_ports=TRANSPARENT_SSL_PORTS):
self.ciphers = ciphers
self.clientcerts = clientcerts
self.no_upstream_cert = no_upstream_cert
@ -49,6 +49,7 @@ class ProxyConfig:
for spec, cert in certs:
self.certstore.add_cert_file(spec, cert)
self.certforward = certforward
self.ssl_ports = ssl_ports
def process_proxy_options(parser, options):
@ -158,3 +159,9 @@ def ssl_option_group(parser):
action="store_true", dest="no_upstream_cert",
help="Don't connect to upstream server to look up certificate details."
)
group.add_argument(
"--ssl-port", action="append", type=int, dest="ssl_ports", default=TRANSPARENT_SSL_PORTS,
metavar="PORT",
help="Can be passed multiple times. Specify destination ports which are assumed to be SSL. "
"Defaults to %s." % str(TRANSPARENT_SSL_PORTS)
)

View File

@ -6,13 +6,6 @@ class ProxyError(Exception):
super(ProxyError, self).__init__(self, message)
self.code, self.headers = code, headers
class ConnectionTypeChange(Exception):
"""
Gets raised if the connection type has been changed (e.g. after HTTP/1.1 101 Switching Protocols).
It's up to the raising ProtocolHandler to specify the new conntype before raising the exception.
"""
pass
class ProxyServerError(Exception):
pass

View File

@ -5,10 +5,9 @@ import socket
from OpenSSL import SSL
from netlib import tcp
from .primitives import ProxyServerError, Log, ProxyError, ConnectionTypeChange, \
AddressPriority
from .primitives import ProxyServerError, Log, ProxyError, AddressPriority
from .connection import ClientConnection, ServerConnection
from ..protocol.handle import handle_messages, handle_error, handle_server_reconnect
from ..protocol.handle import protocol_handler
from .. import version
@ -66,7 +65,6 @@ class ConnectionHandler:
"""@type: libmproxy.proxy.connection.ServerConnection"""
self.channel, self.server_version = channel, server_version
self.close = False
self.conntype = "http"
self.sni = None
@ -77,28 +75,28 @@ class ConnectionHandler:
# Can we already identify the target server and connect to it?
client_ssl, server_ssl = False, False
if self.config.get_upstream_server:
upstream_info = self.config.get_upstream_server(
self.client_conn.connection)
upstream_info = self.config.get_upstream_server(self.client_conn.connection)
self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS)
client_ssl, server_ssl = upstream_info[:2]
if self.check_ignore_address(self.server_conn.address):
self.log("Ignore host: %s:%s" % self.server_conn.address(), "info")
self.conntype = "tcp"
client_ssl, server_ssl = False, False
self.determine_conntype()
self.channel.ask("clientconnect", self)
# Check for existing connection: If an inline script already established a
# connection, do not apply client_ssl or server_ssl.
if self.server_conn and not self.server_conn.connection:
self.establish_server_connection()
if client_ssl or server_ssl:
self.establish_ssl(client=client_ssl, server=server_ssl)
while not self.close:
try:
handle_messages(self.conntype, self)
except ConnectionTypeChange:
self.log("Connection type changed: %s" % self.conntype, "info")
continue
# Delegate handling to the protocol handler
protocol_handler(self.conntype)(self).handle_messages()
except (ProxyError, tcp.NetLibError), e:
handle_error(self.conntype, self, e)
protocol_handler(self.conntype)(self).handle_error(e)
except Exception:
import traceback, sys
@ -106,7 +104,6 @@ class ConnectionHandler:
print >> sys.stderr, traceback.format_exc()
print >> sys.stderr, "mitmproxy has crashed!"
print >> sys.stderr, "Please lodge a bug report at: https://github.com/mitmproxy/mitmproxy"
raise
self.del_server_connection()
self.log("clientdisconnect", "info")
@ -124,12 +121,13 @@ class ConnectionHandler:
self.server_conn = None
self.sni = None
def determine_conntype(self):
if self.server_conn and any(rex.search(self.server_conn.address.host) for rex in self.config.ignore):
self.log("Ignore host: %s" % self.server_conn.address.host, "info")
self.conntype = "tcp"
def check_ignore_address(self, address):
address = tcp.Address.wrap(address)
host = "%s:%s" % (address.host, address.port)
if host and any(rex.search(host) for rex in self.config.ignore):
return True
else:
self.conntype = "http"
return False
def set_server_address(self, address, priority):
"""
@ -175,15 +173,8 @@ class ConnectionHandler:
def establish_ssl(self, client=False, server=False):
"""
Establishes SSL on the existing connection(s) to the server or the client,
as specified by the parameters. If the target server is on the pass-through list,
the conntype attribute will be changed and a ConnTypeChanged exception will be raised.
as specified by the parameters.
"""
# If the host is on our ignore list, change to passthrough/ignore mode.
for host in (self.server_conn.address.host, self.sni):
if host and any(rex.search(host) for rex in self.config.ignore):
self.log("Ignore host: %s" % host, "info")
self.conntype = "tcp"
raise ConnectionTypeChange()
# Logging
if client or server:
@ -224,7 +215,7 @@ class ConnectionHandler:
self.establish_server_connection()
for s in state:
handle_server_reconnect(s[0], self, s[1])
protocol_handler(s[0])(self).handle_server_reconnect(s[1])
self.server_conn.state = state
if had_ssl:
@ -288,4 +279,5 @@ class ConnectionHandler:
# make dang sure it doesn't happen.
except Exception: # pragma: no cover
import traceback
self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error")

View File

@ -175,6 +175,7 @@ class TestHTTPAuth(tservers.HTTPProxTest):
class TestHTTPConnectSSLError(tservers.HTTPProxTest):
certfile = True
def test_go(self):
self.config.ssl_ports.append(self.proxy.port)
p = self.pathoc_raw()
dst = ("localhost", self.proxy.port)
p.connect(connect_to=dst)

View File

@ -95,6 +95,7 @@ class ProxTestBase(object):
confdir = cls.confdir,
authenticator = cls.authenticator,
certforward = cls.certforward,
ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []),
**pconf
)
tmaster = cls.masterclass(cls.config)
@ -267,17 +268,20 @@ class ChainProxTest(ProxTestBase):
Chain n instances of mitmproxy in a row - because we can.
"""
n = 2
chain_config = [lambda port: ProxyConfig(
chain_config = [lambda port, sslports: ProxyConfig(
upstream_server= (False, False, "127.0.0.1", port),
http_form_in = "absolute",
http_form_out = "absolute"
http_form_out = "absolute",
ssl_ports=sslports
)] * n
@classmethod
def setupAll(cls):
super(ChainProxTest, cls).setupAll()
cls.chain = []
for i in range(cls.n):
config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port)
sslports = [cls.server.port, cls.server2.port]
config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port,
sslports)
tmaster = cls.masterclass(config)
tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp)
cls.chain.append(ProxyThread(tmaster))