clean up cert handling, fix mitmproxy/mitmproxy#472

This commit is contained in:
Maximilian Hils 2015-03-07 01:22:02 +01:00
parent 24a3dd59fe
commit dbadc1b613
2 changed files with 88 additions and 56 deletions

View File

@ -302,6 +302,43 @@ class _Connection(object):
except SSL.Error: except SSL.Error:
pass pass
"""
Creates an SSL Context.
"""
def _create_ssl_context(self,
method=SSLv23_METHOD,
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
cipher_list=None
):
"""
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
:rtype : SSL.Context
"""
context = SSL.Context(method)
# Options (NO_SSLv2/3)
if options is not None:
context.set_options(options)
# Workaround for
# https://github.com/pyca/pyopenssl/issues/190
# https://github.com/mitmproxy/mitmproxy/issues/472
context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared.
# Cipher List
if cipher_list:
try:
context.set_cipher_list(cipher_list)
except SSL.Error, v:
raise NetLibError("SSL cipher specification error: %s"%str(v))
# SSLKEYLOGFILE
if log_ssl_key:
context.set_info_callback(log_ssl_key)
return context
class TCPClient(_Connection): class TCPClient(_Connection):
rbufsize = -1 rbufsize = -1
@ -324,32 +361,28 @@ class TCPClient(_Connection):
self.ssl_established = False self.ssl_established = False
self.sni = None self.sni = None
def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None): def create_ssl_context(self, cert=None, **sslctx_kwargs):
""" context = self._create_ssl_context(**sslctx_kwargs)
cert: Path to a file containing both client cert and private key. # Client Certs
options: A bit field consisting of OpenSSL.SSL.OP_* values
"""
context = SSL.Context(method)
if cipher_list:
try:
context.set_cipher_list(cipher_list)
except SSL.Error, v:
raise NetLibError("SSL cipher specification error: %s"%str(v))
if options is not None:
context.set_options(options)
if cert: if cert:
try: try:
context.use_privatekey_file(cert) context.use_privatekey_file(cert)
context.use_certificate_file(cert) context.use_certificate_file(cert)
except SSL.Error, v: except SSL.Error, v:
raise NetLibError("SSL client certificate error: %s"%str(v)) raise NetLibError("SSL client certificate error: %s"%str(v))
return context
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
"""
cert: Path to a file containing both client cert and private key.
options: A bit field consisting of OpenSSL.SSL.OP_* values
"""
context = self.create_ssl_context(**sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection) self.connection = SSL.Connection(context, self.connection)
if sni: if sni:
self.sni = sni self.sni = sni
self.connection.set_tlsext_host_name(sni) self.connection.set_tlsext_host_name(sni)
if log_ssl_key:
context.set_info_callback(log_ssl_key)
self.connection.set_connect_state() self.connection.set_connect_state()
try: try:
self.connection.do_handshake() self.connection.do_handshake()
@ -400,21 +433,21 @@ class BaseHandler(_Connection):
self.ssl_established = False self.ssl_established = False
self.clientcert = None self.clientcert = None
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2, def create_ssl_context(self,
handle_sni=None, request_client_cert=None, cipher_list=None, cert, key,
dhparams=None, chain_file=None): handle_sni=None,
request_client_cert=None,
chain_file=None,
dhparams=None,
**sslctx_kwargs):
""" """
cert: A certutils.SSLCert object. cert: A certutils.SSLCert object.
method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
handle_sni: SNI handler, should take a connection object. Server handle_sni: SNI handler, should take a connection object. Server
name can be retrieved like this: name can be retrieved like this:
connection.get_servername() connection.get_servername()
options: A bit field consisting of OpenSSL.SSL.OP_* values
And you can specify the connection keys as follows: And you can specify the connection keys as follows:
new_context = Context(TLSv1_METHOD) new_context = Context(TLSv1_METHOD)
@ -431,40 +464,38 @@ class BaseHandler(_Connection):
we may be able to make the proper behaviour the default again, but we may be able to make the proper behaviour the default again, but
until then we're conservative. until then we're conservative.
""" """
ctx = SSL.Context(method) context = self._create_ssl_context(**sslctx_kwargs)
if not options is None:
ctx.set_options(options) context.use_privatekey(key)
if chain_file: context.use_certificate(cert.x509)
ctx.load_verify_locations(chain_file)
if cipher_list:
try:
ctx.set_cipher_list(cipher_list)
except SSL.Error, v:
raise NetLibError("SSL cipher specification error: %s"%str(v))
if handle_sni: if handle_sni:
# SNI callback happens during do_handshake() # SNI callback happens during do_handshake()
ctx.set_tlsext_servername_callback(handle_sni) context.set_tlsext_servername_callback(handle_sni)
ctx.use_privatekey(key)
ctx.use_certificate(cert.x509)
if dhparams:
SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams)
if request_client_cert: if request_client_cert:
def ver(*args): def save_cert(conn, cert, errno, depth, preverify_ok):
self.clientcert = certutils.SSLCert(args[1]) self.clientcert = certutils.SSLCert(cert)
# Return true to prevent cert verification error # Return true to prevent cert verification error
return True return True
ctx.set_verify(SSL.VERIFY_PEER, ver) context.set_verify(SSL.VERIFY_PEER, save_cert)
if log_ssl_key:
ctx.set_info_callback(log_ssl_key) # Cert Verify
return ctx if chain_file:
context.load_verify_locations(chain_file)
if dhparams:
SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
return context
def convert_to_ssl(self, cert, key, **sslctx_kwargs): def convert_to_ssl(self, cert, key, **sslctx_kwargs):
""" """
Convert connection to SSL. Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...) For a list of parameters, see BaseHandler._create_ssl_context(...)
""" """
ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) context = self.create_ssl_context(cert, key, **sslctx_kwargs)
self.connection = SSL.Connection(ctx, self.connection) self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state() self.connection.set_accept_state()
try: try:
self.connection.do_handshake() self.connection.do_handshake()
@ -474,7 +505,7 @@ class BaseHandler(_Connection):
self.rfile.set_descriptor(self.connection) self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection)
def handle(self): # pragma: no cover def handle(self): # pragma: no cover
raise NotImplementedError raise NotImplementedError
def settimeout(self, n): def settimeout(self, n):
@ -483,6 +514,7 @@ class BaseHandler(_Connection):
class TCPServer(object): class TCPServer(object):
request_queue_size = 20 request_queue_size = 20
def __init__(self, address): def __init__(self, address):
self.address = Address.wrap(address) self.address = Address.wrap(address)
self.__is_shut_down = threading.Event() self.__is_shut_down = threading.Event()
@ -508,7 +540,7 @@ class TCPServer(object):
while not self.__shutdown_request: while not self.__shutdown_request:
try: try:
r, w, e = select.select([self.socket], [], [], poll_interval) r, w, e = select.select([self.socket], [], [], poll_interval)
except select.error, ex: # pragma: no cover except select.error as ex: # pragma: no cover
if ex[0] == EINTR: if ex[0] == EINTR:
continue continue
else: else:
@ -516,12 +548,12 @@ class TCPServer(object):
if self.socket in r: if self.socket in r:
connection, client_address = self.socket.accept() connection, client_address = self.socket.accept()
t = threading.Thread( t = threading.Thread(
target = self.connection_thread, target=self.connection_thread,
args = (connection, client_address), args=(connection, client_address),
name = "ConnectionThread (%s:%s -> %s:%s)" % name="ConnectionThread (%s:%s -> %s:%s)" %
(client_address[0], client_address[1], (client_address[0], client_address[1],
self.address.host, self.address.port) self.address.host, self.address.port)
) )
t.setDaemon(1) t.setDaemon(1)
t.start() t.start()
finally: finally:

View File

@ -31,7 +31,7 @@ def raises(exc, obj, *args, **kwargs):
:kwargs Arguments to be passed to the callable. :kwargs Arguments to be passed to the callable.
""" """
try: try:
apply(obj, args, kwargs) ret = apply(obj, args, kwargs)
except Exception, v: except Exception, v:
if isinstance(exc, basestring): if isinstance(exc, basestring):
if exc.lower() in str(v).lower(): if exc.lower() in str(v).lower():
@ -51,6 +51,6 @@ def raises(exc, obj, *args, **kwargs):
exc.__name__, v.__class__.__name__, str(v) exc.__name__, v.__class__.__name__, str(v)
) )
) )
raise AssertionError("No exception raised.") raise AssertionError("No exception raised. Return value: {}".format(ret))
test_data = utils.Data(__name__) test_data = utils.Data(__name__)