mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
clean up cert handling, fix mitmproxy/mitmproxy#472
This commit is contained in:
parent
24a3dd59fe
commit
dbadc1b613
140
netlib/tcp.py
140
netlib/tcp.py
@ -302,6 +302,43 @@ class _Connection(object):
|
||||
except SSL.Error:
|
||||
pass
|
||||
|
||||
"""
|
||||
Creates an SSL Context.
|
||||
"""
|
||||
def _create_ssl_context(self,
|
||||
method=SSLv23_METHOD,
|
||||
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
|
||||
cipher_list=None
|
||||
):
|
||||
"""
|
||||
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
|
||||
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
|
||||
:rtype : SSL.Context
|
||||
"""
|
||||
context = SSL.Context(method)
|
||||
# Options (NO_SSLv2/3)
|
||||
if options is not None:
|
||||
context.set_options(options)
|
||||
|
||||
# Workaround for
|
||||
# https://github.com/pyca/pyopenssl/issues/190
|
||||
# https://github.com/mitmproxy/mitmproxy/issues/472
|
||||
context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared.
|
||||
|
||||
# Cipher List
|
||||
if cipher_list:
|
||||
try:
|
||||
context.set_cipher_list(cipher_list)
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL cipher specification error: %s"%str(v))
|
||||
|
||||
# SSLKEYLOGFILE
|
||||
if log_ssl_key:
|
||||
context.set_info_callback(log_ssl_key)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class TCPClient(_Connection):
|
||||
rbufsize = -1
|
||||
@ -324,32 +361,28 @@ class TCPClient(_Connection):
|
||||
self.ssl_established = False
|
||||
self.sni = None
|
||||
|
||||
def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None):
|
||||
"""
|
||||
cert: Path to a file containing both client cert and private key.
|
||||
|
||||
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||
"""
|
||||
context = SSL.Context(method)
|
||||
if cipher_list:
|
||||
try:
|
||||
context.set_cipher_list(cipher_list)
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL cipher specification error: %s"%str(v))
|
||||
if options is not None:
|
||||
context.set_options(options)
|
||||
def create_ssl_context(self, cert=None, **sslctx_kwargs):
|
||||
context = self._create_ssl_context(**sslctx_kwargs)
|
||||
# Client Certs
|
||||
if cert:
|
||||
try:
|
||||
context.use_privatekey_file(cert)
|
||||
context.use_certificate_file(cert)
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL client certificate error: %s"%str(v))
|
||||
return context
|
||||
|
||||
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
|
||||
"""
|
||||
cert: Path to a file containing both client cert and private key.
|
||||
|
||||
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||
"""
|
||||
context = self.create_ssl_context(**sslctx_kwargs)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
if sni:
|
||||
self.sni = sni
|
||||
self.connection.set_tlsext_host_name(sni)
|
||||
if log_ssl_key:
|
||||
context.set_info_callback(log_ssl_key)
|
||||
self.connection.set_connect_state()
|
||||
try:
|
||||
self.connection.do_handshake()
|
||||
@ -400,21 +433,21 @@ class BaseHandler(_Connection):
|
||||
self.ssl_established = False
|
||||
self.clientcert = None
|
||||
|
||||
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2,
|
||||
handle_sni=None, request_client_cert=None, cipher_list=None,
|
||||
dhparams=None, chain_file=None):
|
||||
def create_ssl_context(self,
|
||||
cert, key,
|
||||
handle_sni=None,
|
||||
request_client_cert=None,
|
||||
chain_file=None,
|
||||
dhparams=None,
|
||||
**sslctx_kwargs):
|
||||
"""
|
||||
cert: A certutils.SSLCert object.
|
||||
|
||||
method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
|
||||
|
||||
handle_sni: SNI handler, should take a connection object. Server
|
||||
name can be retrieved like this:
|
||||
|
||||
connection.get_servername()
|
||||
|
||||
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||
|
||||
And you can specify the connection keys as follows:
|
||||
|
||||
new_context = Context(TLSv1_METHOD)
|
||||
@ -431,40 +464,38 @@ class BaseHandler(_Connection):
|
||||
we may be able to make the proper behaviour the default again, but
|
||||
until then we're conservative.
|
||||
"""
|
||||
ctx = SSL.Context(method)
|
||||
if not options is None:
|
||||
ctx.set_options(options)
|
||||
if chain_file:
|
||||
ctx.load_verify_locations(chain_file)
|
||||
if cipher_list:
|
||||
try:
|
||||
ctx.set_cipher_list(cipher_list)
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL cipher specification error: %s"%str(v))
|
||||
context = self._create_ssl_context(**sslctx_kwargs)
|
||||
|
||||
context.use_privatekey(key)
|
||||
context.use_certificate(cert.x509)
|
||||
|
||||
if handle_sni:
|
||||
# SNI callback happens during do_handshake()
|
||||
ctx.set_tlsext_servername_callback(handle_sni)
|
||||
ctx.use_privatekey(key)
|
||||
ctx.use_certificate(cert.x509)
|
||||
if dhparams:
|
||||
SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams)
|
||||
context.set_tlsext_servername_callback(handle_sni)
|
||||
|
||||
if request_client_cert:
|
||||
def ver(*args):
|
||||
self.clientcert = certutils.SSLCert(args[1])
|
||||
def save_cert(conn, cert, errno, depth, preverify_ok):
|
||||
self.clientcert = certutils.SSLCert(cert)
|
||||
# Return true to prevent cert verification error
|
||||
return True
|
||||
ctx.set_verify(SSL.VERIFY_PEER, ver)
|
||||
if log_ssl_key:
|
||||
ctx.set_info_callback(log_ssl_key)
|
||||
return ctx
|
||||
context.set_verify(SSL.VERIFY_PEER, save_cert)
|
||||
|
||||
# Cert Verify
|
||||
if chain_file:
|
||||
context.load_verify_locations(chain_file)
|
||||
|
||||
if dhparams:
|
||||
SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
|
||||
|
||||
return context
|
||||
|
||||
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
|
||||
"""
|
||||
Convert connection to SSL.
|
||||
For a list of parameters, see BaseHandler._create_ssl_context(...)
|
||||
"""
|
||||
ctx = self._create_ssl_context(cert, key, **sslctx_kwargs)
|
||||
self.connection = SSL.Connection(ctx, self.connection)
|
||||
context = self.create_ssl_context(cert, key, **sslctx_kwargs)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
try:
|
||||
self.connection.do_handshake()
|
||||
@ -474,7 +505,7 @@ class BaseHandler(_Connection):
|
||||
self.rfile.set_descriptor(self.connection)
|
||||
self.wfile.set_descriptor(self.connection)
|
||||
|
||||
def handle(self): # pragma: no cover
|
||||
def handle(self): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def settimeout(self, n):
|
||||
@ -483,6 +514,7 @@ class BaseHandler(_Connection):
|
||||
|
||||
class TCPServer(object):
|
||||
request_queue_size = 20
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = Address.wrap(address)
|
||||
self.__is_shut_down = threading.Event()
|
||||
@ -508,7 +540,7 @@ class TCPServer(object):
|
||||
while not self.__shutdown_request:
|
||||
try:
|
||||
r, w, e = select.select([self.socket], [], [], poll_interval)
|
||||
except select.error, ex: # pragma: no cover
|
||||
except select.error as ex: # pragma: no cover
|
||||
if ex[0] == EINTR:
|
||||
continue
|
||||
else:
|
||||
@ -516,12 +548,12 @@ class TCPServer(object):
|
||||
if self.socket in r:
|
||||
connection, client_address = self.socket.accept()
|
||||
t = threading.Thread(
|
||||
target = self.connection_thread,
|
||||
args = (connection, client_address),
|
||||
name = "ConnectionThread (%s:%s -> %s:%s)" %
|
||||
(client_address[0], client_address[1],
|
||||
self.address.host, self.address.port)
|
||||
)
|
||||
target=self.connection_thread,
|
||||
args=(connection, client_address),
|
||||
name="ConnectionThread (%s:%s -> %s:%s)" %
|
||||
(client_address[0], client_address[1],
|
||||
self.address.host, self.address.port)
|
||||
)
|
||||
t.setDaemon(1)
|
||||
t.start()
|
||||
finally:
|
||||
|
@ -31,7 +31,7 @@ def raises(exc, obj, *args, **kwargs):
|
||||
:kwargs Arguments to be passed to the callable.
|
||||
"""
|
||||
try:
|
||||
apply(obj, args, kwargs)
|
||||
ret = apply(obj, args, kwargs)
|
||||
except Exception, v:
|
||||
if isinstance(exc, basestring):
|
||||
if exc.lower() in str(v).lower():
|
||||
@ -51,6 +51,6 @@ def raises(exc, obj, *args, **kwargs):
|
||||
exc.__name__, v.__class__.__name__, str(v)
|
||||
)
|
||||
)
|
||||
raise AssertionError("No exception raised.")
|
||||
raise AssertionError("No exception raised. Return value: {}".format(ret))
|
||||
|
||||
test_data = utils.Data(__name__)
|
||||
|
Loading…
Reference in New Issue
Block a user