clean up cert handling, fix mitmproxy/mitmproxy#472

This commit is contained in:
Maximilian Hils 2015-03-07 01:22:02 +01:00
parent 24a3dd59fe
commit dbadc1b613
2 changed files with 88 additions and 56 deletions

View File

@ -302,6 +302,43 @@ class _Connection(object):
except SSL.Error:
pass
"""
Creates an SSL Context.
"""
def _create_ssl_context(self,
method=SSLv23_METHOD,
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
cipher_list=None
):
"""
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
:rtype : SSL.Context
"""
context = SSL.Context(method)
# Options (NO_SSLv2/3)
if options is not None:
context.set_options(options)
# Workaround for
# https://github.com/pyca/pyopenssl/issues/190
# https://github.com/mitmproxy/mitmproxy/issues/472
context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared.
# Cipher List
if cipher_list:
try:
context.set_cipher_list(cipher_list)
except SSL.Error, v:
raise NetLibError("SSL cipher specification error: %s"%str(v))
# SSLKEYLOGFILE
if log_ssl_key:
context.set_info_callback(log_ssl_key)
return context
class TCPClient(_Connection):
rbufsize = -1
@ -324,32 +361,28 @@ class TCPClient(_Connection):
self.ssl_established = False
self.sni = None
def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None):
"""
cert: Path to a file containing both client cert and private key.
options: A bit field consisting of OpenSSL.SSL.OP_* values
"""
context = SSL.Context(method)
if cipher_list:
try:
context.set_cipher_list(cipher_list)
except SSL.Error, v:
raise NetLibError("SSL cipher specification error: %s"%str(v))
if options is not None:
context.set_options(options)
def create_ssl_context(self, cert=None, **sslctx_kwargs):
context = self._create_ssl_context(**sslctx_kwargs)
# Client Certs
if cert:
try:
context.use_privatekey_file(cert)
context.use_certificate_file(cert)
except SSL.Error, v:
raise NetLibError("SSL client certificate error: %s"%str(v))
return context
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
"""
cert: Path to a file containing both client cert and private key.
options: A bit field consisting of OpenSSL.SSL.OP_* values
"""
context = self.create_ssl_context(**sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
if sni:
self.sni = sni
self.connection.set_tlsext_host_name(sni)
if log_ssl_key:
context.set_info_callback(log_ssl_key)
self.connection.set_connect_state()
try:
self.connection.do_handshake()
@ -400,21 +433,21 @@ class BaseHandler(_Connection):
self.ssl_established = False
self.clientcert = None
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2,
handle_sni=None, request_client_cert=None, cipher_list=None,
dhparams=None, chain_file=None):
def create_ssl_context(self,
cert, key,
handle_sni=None,
request_client_cert=None,
chain_file=None,
dhparams=None,
**sslctx_kwargs):
"""
cert: A certutils.SSLCert object.
method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
handle_sni: SNI handler, should take a connection object. Server
name can be retrieved like this:
connection.get_servername()
options: A bit field consisting of OpenSSL.SSL.OP_* values
And you can specify the connection keys as follows:
new_context = Context(TLSv1_METHOD)
@ -431,40 +464,38 @@ class BaseHandler(_Connection):
we may be able to make the proper behaviour the default again, but
until then we're conservative.
"""
ctx = SSL.Context(method)
if not options is None:
ctx.set_options(options)
if chain_file:
ctx.load_verify_locations(chain_file)
if cipher_list:
try:
ctx.set_cipher_list(cipher_list)
except SSL.Error, v:
raise NetLibError("SSL cipher specification error: %s"%str(v))
context = self._create_ssl_context(**sslctx_kwargs)
context.use_privatekey(key)
context.use_certificate(cert.x509)
if handle_sni:
# SNI callback happens during do_handshake()
ctx.set_tlsext_servername_callback(handle_sni)
ctx.use_privatekey(key)
ctx.use_certificate(cert.x509)
if dhparams:
SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams)
context.set_tlsext_servername_callback(handle_sni)
if request_client_cert:
def ver(*args):
self.clientcert = certutils.SSLCert(args[1])
def save_cert(conn, cert, errno, depth, preverify_ok):
self.clientcert = certutils.SSLCert(cert)
# Return true to prevent cert verification error
return True
ctx.set_verify(SSL.VERIFY_PEER, ver)
if log_ssl_key:
ctx.set_info_callback(log_ssl_key)
return ctx
context.set_verify(SSL.VERIFY_PEER, save_cert)
# Cert Verify
if chain_file:
context.load_verify_locations(chain_file)
if dhparams:
SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
return context
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
"""
Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...)
"""
ctx = self._create_ssl_context(cert, key, **sslctx_kwargs)
self.connection = SSL.Connection(ctx, self.connection)
context = self.create_ssl_context(cert, key, **sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state()
try:
self.connection.do_handshake()
@ -483,6 +514,7 @@ class BaseHandler(_Connection):
class TCPServer(object):
request_queue_size = 20
def __init__(self, address):
self.address = Address.wrap(address)
self.__is_shut_down = threading.Event()
@ -508,7 +540,7 @@ class TCPServer(object):
while not self.__shutdown_request:
try:
r, w, e = select.select([self.socket], [], [], poll_interval)
except select.error, ex: # pragma: no cover
except select.error as ex: # pragma: no cover
if ex[0] == EINTR:
continue
else:
@ -516,9 +548,9 @@ class TCPServer(object):
if self.socket in r:
connection, client_address = self.socket.accept()
t = threading.Thread(
target = self.connection_thread,
args = (connection, client_address),
name = "ConnectionThread (%s:%s -> %s:%s)" %
target=self.connection_thread,
args=(connection, client_address),
name="ConnectionThread (%s:%s -> %s:%s)" %
(client_address[0], client_address[1],
self.address.host, self.address.port)
)

View File

@ -31,7 +31,7 @@ def raises(exc, obj, *args, **kwargs):
:kwargs Arguments to be passed to the callable.
"""
try:
apply(obj, args, kwargs)
ret = apply(obj, args, kwargs)
except Exception, v:
if isinstance(exc, basestring):
if exc.lower() in str(v).lower():
@ -51,6 +51,6 @@ def raises(exc, obj, *args, **kwargs):
exc.__name__, v.__class__.__name__, str(v)
)
)
raise AssertionError("No exception raised.")
raise AssertionError("No exception raised. Return value: {}".format(ret))
test_data = utils.Data(__name__)