mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
clean up cert handling, fix mitmproxy/mitmproxy#472
This commit is contained in:
parent
24a3dd59fe
commit
dbadc1b613
140
netlib/tcp.py
140
netlib/tcp.py
@ -302,6 +302,43 @@ class _Connection(object):
|
|||||||
except SSL.Error:
|
except SSL.Error:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
"""
|
||||||
|
Creates an SSL Context.
|
||||||
|
"""
|
||||||
|
def _create_ssl_context(self,
|
||||||
|
method=SSLv23_METHOD,
|
||||||
|
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
|
||||||
|
cipher_list=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
|
||||||
|
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||||
|
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
|
||||||
|
:rtype : SSL.Context
|
||||||
|
"""
|
||||||
|
context = SSL.Context(method)
|
||||||
|
# Options (NO_SSLv2/3)
|
||||||
|
if options is not None:
|
||||||
|
context.set_options(options)
|
||||||
|
|
||||||
|
# Workaround for
|
||||||
|
# https://github.com/pyca/pyopenssl/issues/190
|
||||||
|
# https://github.com/mitmproxy/mitmproxy/issues/472
|
||||||
|
context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared.
|
||||||
|
|
||||||
|
# Cipher List
|
||||||
|
if cipher_list:
|
||||||
|
try:
|
||||||
|
context.set_cipher_list(cipher_list)
|
||||||
|
except SSL.Error, v:
|
||||||
|
raise NetLibError("SSL cipher specification error: %s"%str(v))
|
||||||
|
|
||||||
|
# SSLKEYLOGFILE
|
||||||
|
if log_ssl_key:
|
||||||
|
context.set_info_callback(log_ssl_key)
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
class TCPClient(_Connection):
|
class TCPClient(_Connection):
|
||||||
rbufsize = -1
|
rbufsize = -1
|
||||||
@ -324,32 +361,28 @@ class TCPClient(_Connection):
|
|||||||
self.ssl_established = False
|
self.ssl_established = False
|
||||||
self.sni = None
|
self.sni = None
|
||||||
|
|
||||||
def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None):
|
def create_ssl_context(self, cert=None, **sslctx_kwargs):
|
||||||
"""
|
context = self._create_ssl_context(**sslctx_kwargs)
|
||||||
cert: Path to a file containing both client cert and private key.
|
# Client Certs
|
||||||
|
|
||||||
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
|
||||||
"""
|
|
||||||
context = SSL.Context(method)
|
|
||||||
if cipher_list:
|
|
||||||
try:
|
|
||||||
context.set_cipher_list(cipher_list)
|
|
||||||
except SSL.Error, v:
|
|
||||||
raise NetLibError("SSL cipher specification error: %s"%str(v))
|
|
||||||
if options is not None:
|
|
||||||
context.set_options(options)
|
|
||||||
if cert:
|
if cert:
|
||||||
try:
|
try:
|
||||||
context.use_privatekey_file(cert)
|
context.use_privatekey_file(cert)
|
||||||
context.use_certificate_file(cert)
|
context.use_certificate_file(cert)
|
||||||
except SSL.Error, v:
|
except SSL.Error, v:
|
||||||
raise NetLibError("SSL client certificate error: %s"%str(v))
|
raise NetLibError("SSL client certificate error: %s"%str(v))
|
||||||
|
return context
|
||||||
|
|
||||||
|
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
|
||||||
|
"""
|
||||||
|
cert: Path to a file containing both client cert and private key.
|
||||||
|
|
||||||
|
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||||
|
"""
|
||||||
|
context = self.create_ssl_context(**sslctx_kwargs)
|
||||||
self.connection = SSL.Connection(context, self.connection)
|
self.connection = SSL.Connection(context, self.connection)
|
||||||
if sni:
|
if sni:
|
||||||
self.sni = sni
|
self.sni = sni
|
||||||
self.connection.set_tlsext_host_name(sni)
|
self.connection.set_tlsext_host_name(sni)
|
||||||
if log_ssl_key:
|
|
||||||
context.set_info_callback(log_ssl_key)
|
|
||||||
self.connection.set_connect_state()
|
self.connection.set_connect_state()
|
||||||
try:
|
try:
|
||||||
self.connection.do_handshake()
|
self.connection.do_handshake()
|
||||||
@ -400,21 +433,21 @@ class BaseHandler(_Connection):
|
|||||||
self.ssl_established = False
|
self.ssl_established = False
|
||||||
self.clientcert = None
|
self.clientcert = None
|
||||||
|
|
||||||
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2,
|
def create_ssl_context(self,
|
||||||
handle_sni=None, request_client_cert=None, cipher_list=None,
|
cert, key,
|
||||||
dhparams=None, chain_file=None):
|
handle_sni=None,
|
||||||
|
request_client_cert=None,
|
||||||
|
chain_file=None,
|
||||||
|
dhparams=None,
|
||||||
|
**sslctx_kwargs):
|
||||||
"""
|
"""
|
||||||
cert: A certutils.SSLCert object.
|
cert: A certutils.SSLCert object.
|
||||||
|
|
||||||
method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
|
|
||||||
|
|
||||||
handle_sni: SNI handler, should take a connection object. Server
|
handle_sni: SNI handler, should take a connection object. Server
|
||||||
name can be retrieved like this:
|
name can be retrieved like this:
|
||||||
|
|
||||||
connection.get_servername()
|
connection.get_servername()
|
||||||
|
|
||||||
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
|
||||||
|
|
||||||
And you can specify the connection keys as follows:
|
And you can specify the connection keys as follows:
|
||||||
|
|
||||||
new_context = Context(TLSv1_METHOD)
|
new_context = Context(TLSv1_METHOD)
|
||||||
@ -431,40 +464,38 @@ class BaseHandler(_Connection):
|
|||||||
we may be able to make the proper behaviour the default again, but
|
we may be able to make the proper behaviour the default again, but
|
||||||
until then we're conservative.
|
until then we're conservative.
|
||||||
"""
|
"""
|
||||||
ctx = SSL.Context(method)
|
context = self._create_ssl_context(**sslctx_kwargs)
|
||||||
if not options is None:
|
|
||||||
ctx.set_options(options)
|
context.use_privatekey(key)
|
||||||
if chain_file:
|
context.use_certificate(cert.x509)
|
||||||
ctx.load_verify_locations(chain_file)
|
|
||||||
if cipher_list:
|
|
||||||
try:
|
|
||||||
ctx.set_cipher_list(cipher_list)
|
|
||||||
except SSL.Error, v:
|
|
||||||
raise NetLibError("SSL cipher specification error: %s"%str(v))
|
|
||||||
if handle_sni:
|
if handle_sni:
|
||||||
# SNI callback happens during do_handshake()
|
# SNI callback happens during do_handshake()
|
||||||
ctx.set_tlsext_servername_callback(handle_sni)
|
context.set_tlsext_servername_callback(handle_sni)
|
||||||
ctx.use_privatekey(key)
|
|
||||||
ctx.use_certificate(cert.x509)
|
|
||||||
if dhparams:
|
|
||||||
SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams)
|
|
||||||
if request_client_cert:
|
if request_client_cert:
|
||||||
def ver(*args):
|
def save_cert(conn, cert, errno, depth, preverify_ok):
|
||||||
self.clientcert = certutils.SSLCert(args[1])
|
self.clientcert = certutils.SSLCert(cert)
|
||||||
# Return true to prevent cert verification error
|
# Return true to prevent cert verification error
|
||||||
return True
|
return True
|
||||||
ctx.set_verify(SSL.VERIFY_PEER, ver)
|
context.set_verify(SSL.VERIFY_PEER, save_cert)
|
||||||
if log_ssl_key:
|
|
||||||
ctx.set_info_callback(log_ssl_key)
|
# Cert Verify
|
||||||
return ctx
|
if chain_file:
|
||||||
|
context.load_verify_locations(chain_file)
|
||||||
|
|
||||||
|
if dhparams:
|
||||||
|
SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
|
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
|
||||||
"""
|
"""
|
||||||
Convert connection to SSL.
|
Convert connection to SSL.
|
||||||
For a list of parameters, see BaseHandler._create_ssl_context(...)
|
For a list of parameters, see BaseHandler._create_ssl_context(...)
|
||||||
"""
|
"""
|
||||||
ctx = self._create_ssl_context(cert, key, **sslctx_kwargs)
|
context = self.create_ssl_context(cert, key, **sslctx_kwargs)
|
||||||
self.connection = SSL.Connection(ctx, self.connection)
|
self.connection = SSL.Connection(context, self.connection)
|
||||||
self.connection.set_accept_state()
|
self.connection.set_accept_state()
|
||||||
try:
|
try:
|
||||||
self.connection.do_handshake()
|
self.connection.do_handshake()
|
||||||
@ -474,7 +505,7 @@ class BaseHandler(_Connection):
|
|||||||
self.rfile.set_descriptor(self.connection)
|
self.rfile.set_descriptor(self.connection)
|
||||||
self.wfile.set_descriptor(self.connection)
|
self.wfile.set_descriptor(self.connection)
|
||||||
|
|
||||||
def handle(self): # pragma: no cover
|
def handle(self): # pragma: no cover
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def settimeout(self, n):
|
def settimeout(self, n):
|
||||||
@ -483,6 +514,7 @@ class BaseHandler(_Connection):
|
|||||||
|
|
||||||
class TCPServer(object):
|
class TCPServer(object):
|
||||||
request_queue_size = 20
|
request_queue_size = 20
|
||||||
|
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
self.address = Address.wrap(address)
|
self.address = Address.wrap(address)
|
||||||
self.__is_shut_down = threading.Event()
|
self.__is_shut_down = threading.Event()
|
||||||
@ -508,7 +540,7 @@ class TCPServer(object):
|
|||||||
while not self.__shutdown_request:
|
while not self.__shutdown_request:
|
||||||
try:
|
try:
|
||||||
r, w, e = select.select([self.socket], [], [], poll_interval)
|
r, w, e = select.select([self.socket], [], [], poll_interval)
|
||||||
except select.error, ex: # pragma: no cover
|
except select.error as ex: # pragma: no cover
|
||||||
if ex[0] == EINTR:
|
if ex[0] == EINTR:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -516,12 +548,12 @@ class TCPServer(object):
|
|||||||
if self.socket in r:
|
if self.socket in r:
|
||||||
connection, client_address = self.socket.accept()
|
connection, client_address = self.socket.accept()
|
||||||
t = threading.Thread(
|
t = threading.Thread(
|
||||||
target = self.connection_thread,
|
target=self.connection_thread,
|
||||||
args = (connection, client_address),
|
args=(connection, client_address),
|
||||||
name = "ConnectionThread (%s:%s -> %s:%s)" %
|
name="ConnectionThread (%s:%s -> %s:%s)" %
|
||||||
(client_address[0], client_address[1],
|
(client_address[0], client_address[1],
|
||||||
self.address.host, self.address.port)
|
self.address.host, self.address.port)
|
||||||
)
|
)
|
||||||
t.setDaemon(1)
|
t.setDaemon(1)
|
||||||
t.start()
|
t.start()
|
||||||
finally:
|
finally:
|
||||||
|
@ -31,7 +31,7 @@ def raises(exc, obj, *args, **kwargs):
|
|||||||
:kwargs Arguments to be passed to the callable.
|
:kwargs Arguments to be passed to the callable.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
apply(obj, args, kwargs)
|
ret = apply(obj, args, kwargs)
|
||||||
except Exception, v:
|
except Exception, v:
|
||||||
if isinstance(exc, basestring):
|
if isinstance(exc, basestring):
|
||||||
if exc.lower() in str(v).lower():
|
if exc.lower() in str(v).lower():
|
||||||
@ -51,6 +51,6 @@ def raises(exc, obj, *args, **kwargs):
|
|||||||
exc.__name__, v.__class__.__name__, str(v)
|
exc.__name__, v.__class__.__name__, str(v)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
raise AssertionError("No exception raised.")
|
raise AssertionError("No exception raised. Return value: {}".format(ret))
|
||||||
|
|
||||||
test_data = utils.Data(__name__)
|
test_data = utils.Data(__name__)
|
||||||
|
Loading…
Reference in New Issue
Block a user