mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
add ALPN support to TCP abstraction
This commit is contained in:
parent
d50b9be0d5
commit
780836b182
@ -360,7 +360,9 @@ class _Connection(object):
|
|||||||
def _create_ssl_context(self,
|
def _create_ssl_context(self,
|
||||||
method=SSLv23_METHOD,
|
method=SSLv23_METHOD,
|
||||||
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
|
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
|
||||||
cipher_list=None
|
cipher_list=None,
|
||||||
|
alpn_protos=None,
|
||||||
|
alpn_select=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
|
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
|
||||||
@ -389,6 +391,17 @@ class _Connection(object):
|
|||||||
if log_ssl_key:
|
if log_ssl_key:
|
||||||
context.set_info_callback(log_ssl_key)
|
context.set_info_callback(log_ssl_key)
|
||||||
|
|
||||||
|
# advertise application layer protocols
|
||||||
|
if alpn_protos is not None:
|
||||||
|
context.set_alpn_protos(alpn_protos)
|
||||||
|
|
||||||
|
# select application layer protocol
|
||||||
|
if alpn_select is not None:
|
||||||
|
def alpn_select_f(conn, options):
|
||||||
|
return bytes(alpn_select)
|
||||||
|
|
||||||
|
context.set_alpn_select_callback(alpn_select_f)
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
@ -413,8 +426,8 @@ class TCPClient(_Connection):
|
|||||||
self.ssl_established = False
|
self.ssl_established = False
|
||||||
self.sni = None
|
self.sni = None
|
||||||
|
|
||||||
def create_ssl_context(self, cert=None, **sslctx_kwargs):
|
def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs):
|
||||||
context = self._create_ssl_context(**sslctx_kwargs)
|
context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
|
||||||
# Client Certs
|
# Client Certs
|
||||||
if cert:
|
if cert:
|
||||||
try:
|
try:
|
||||||
@ -424,13 +437,13 @@ class TCPClient(_Connection):
|
|||||||
raise NetLibError("SSL client certificate error: %s" % str(v))
|
raise NetLibError("SSL client certificate error: %s" % str(v))
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
|
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
|
||||||
"""
|
"""
|
||||||
cert: Path to a file containing both client cert and private key.
|
cert: Path to a file containing both client cert and private key.
|
||||||
|
|
||||||
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||||
"""
|
"""
|
||||||
context = self.create_ssl_context(**sslctx_kwargs)
|
context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
|
||||||
self.connection = SSL.Connection(context, self.connection)
|
self.connection = SSL.Connection(context, self.connection)
|
||||||
if sni:
|
if sni:
|
||||||
self.sni = sni
|
self.sni = sni
|
||||||
@ -465,6 +478,9 @@ class TCPClient(_Connection):
|
|||||||
def gettimeout(self):
|
def gettimeout(self):
|
||||||
return self.connection.gettimeout()
|
return self.connection.gettimeout()
|
||||||
|
|
||||||
|
def get_alpn_proto_negotiated(self):
|
||||||
|
return self.connection.get_alpn_proto_negotiated()
|
||||||
|
|
||||||
|
|
||||||
class BaseHandler(_Connection):
|
class BaseHandler(_Connection):
|
||||||
|
|
||||||
@ -492,6 +508,7 @@ class BaseHandler(_Connection):
|
|||||||
request_client_cert=None,
|
request_client_cert=None,
|
||||||
chain_file=None,
|
chain_file=None,
|
||||||
dhparams=None,
|
dhparams=None,
|
||||||
|
alpn_select=None,
|
||||||
**sslctx_kwargs):
|
**sslctx_kwargs):
|
||||||
"""
|
"""
|
||||||
cert: A certutils.SSLCert object.
|
cert: A certutils.SSLCert object.
|
||||||
@ -517,7 +534,8 @@ class BaseHandler(_Connection):
|
|||||||
we may be able to make the proper behaviour the default again, but
|
we may be able to make the proper behaviour the default again, but
|
||||||
until then we're conservative.
|
until then we're conservative.
|
||||||
"""
|
"""
|
||||||
context = self._create_ssl_context(**sslctx_kwargs)
|
|
||||||
|
context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs)
|
||||||
|
|
||||||
context.use_privatekey(key)
|
context.use_privatekey(key)
|
||||||
context.use_certificate(cert.x509)
|
context.use_certificate(cert.x509)
|
||||||
@ -542,12 +560,13 @@ class BaseHandler(_Connection):
|
|||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
|
def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
|
||||||
"""
|
"""
|
||||||
Convert connection to SSL.
|
Convert connection to SSL.
|
||||||
For a list of parameters, see BaseHandler._create_ssl_context(...)
|
For a list of parameters, see BaseHandler._create_ssl_context(...)
|
||||||
"""
|
"""
|
||||||
context = self.create_ssl_context(cert, key, **sslctx_kwargs)
|
|
||||||
|
context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs)
|
||||||
self.connection = SSL.Connection(context, self.connection)
|
self.connection = SSL.Connection(context, self.connection)
|
||||||
self.connection.set_accept_state()
|
self.connection.set_accept_state()
|
||||||
try:
|
try:
|
||||||
|
@ -82,7 +82,8 @@ class TServer(tcp.TCPServer):
|
|||||||
request_client_cert=self.ssl["request_client_cert"],
|
request_client_cert=self.ssl["request_client_cert"],
|
||||||
cipher_list=self.ssl.get("cipher_list", None),
|
cipher_list=self.ssl.get("cipher_list", None),
|
||||||
dhparams=self.ssl.get("dhparams", None),
|
dhparams=self.ssl.get("dhparams", None),
|
||||||
chain_file=self.ssl.get("chain_file", None)
|
chain_file=self.ssl.get("chain_file", None),
|
||||||
|
alpn_select=self.ssl.get("alpn_select", None)
|
||||||
)
|
)
|
||||||
h.handle()
|
h.handle()
|
||||||
h.finish()
|
h.finish()
|
||||||
|
@ -389,6 +389,24 @@ class TestTimeOut(test.ServerTestBase):
|
|||||||
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
|
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestALPN(test.ServerTestBase):
|
||||||
|
handler = HangHandler
|
||||||
|
ssl = dict(
|
||||||
|
cert=tutils.test_data.path("data/server.crt"),
|
||||||
|
key=tutils.test_data.path("data/server.key"),
|
||||||
|
request_client_cert=False,
|
||||||
|
v3_only=False,
|
||||||
|
alpn_select="h2"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_alpn(self):
|
||||||
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
|
c.connect()
|
||||||
|
c.convert_to_ssl(alpn_protos=["h2"])
|
||||||
|
print "ALPN: %s" % c.get_alpn_proto_negotiated()
|
||||||
|
assert c.get_alpn_proto_negotiated() == "h2"
|
||||||
|
|
||||||
|
|
||||||
class TestSSLTimeOut(test.ServerTestBase):
|
class TestSSLTimeOut(test.ServerTestBase):
|
||||||
handler = HangHandler
|
handler = HangHandler
|
||||||
ssl = dict(
|
ssl = dict(
|
||||||
|
Loading…
Reference in New Issue
Block a user