mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
add ALPN support to TCP abstraction
This commit is contained in:
parent
d50b9be0d5
commit
780836b182
@ -360,7 +360,9 @@ class _Connection(object):
|
||||
def _create_ssl_context(self,
|
||||
method=SSLv23_METHOD,
|
||||
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
|
||||
cipher_list=None
|
||||
cipher_list=None,
|
||||
alpn_protos=None,
|
||||
alpn_select=None,
|
||||
):
|
||||
"""
|
||||
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
|
||||
@ -389,6 +391,17 @@ class _Connection(object):
|
||||
if log_ssl_key:
|
||||
context.set_info_callback(log_ssl_key)
|
||||
|
||||
# advertise application layer protocols
|
||||
if alpn_protos is not None:
|
||||
context.set_alpn_protos(alpn_protos)
|
||||
|
||||
# select application layer protocol
|
||||
if alpn_select is not None:
|
||||
def alpn_select_f(conn, options):
|
||||
return bytes(alpn_select)
|
||||
|
||||
context.set_alpn_select_callback(alpn_select_f)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@ -413,8 +426,8 @@ class TCPClient(_Connection):
|
||||
self.ssl_established = False
|
||||
self.sni = None
|
||||
|
||||
def create_ssl_context(self, cert=None, **sslctx_kwargs):
|
||||
context = self._create_ssl_context(**sslctx_kwargs)
|
||||
def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs):
|
||||
context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
|
||||
# Client Certs
|
||||
if cert:
|
||||
try:
|
||||
@ -424,13 +437,13 @@ class TCPClient(_Connection):
|
||||
raise NetLibError("SSL client certificate error: %s" % str(v))
|
||||
return context
|
||||
|
||||
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
|
||||
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
|
||||
"""
|
||||
cert: Path to a file containing both client cert and private key.
|
||||
|
||||
options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||
"""
|
||||
context = self.create_ssl_context(**sslctx_kwargs)
|
||||
context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
if sni:
|
||||
self.sni = sni
|
||||
@ -465,6 +478,9 @@ class TCPClient(_Connection):
|
||||
def gettimeout(self):
|
||||
return self.connection.gettimeout()
|
||||
|
||||
def get_alpn_proto_negotiated(self):
|
||||
return self.connection.get_alpn_proto_negotiated()
|
||||
|
||||
|
||||
class BaseHandler(_Connection):
|
||||
|
||||
@ -492,6 +508,7 @@ class BaseHandler(_Connection):
|
||||
request_client_cert=None,
|
||||
chain_file=None,
|
||||
dhparams=None,
|
||||
alpn_select=None,
|
||||
**sslctx_kwargs):
|
||||
"""
|
||||
cert: A certutils.SSLCert object.
|
||||
@ -517,7 +534,8 @@ class BaseHandler(_Connection):
|
||||
we may be able to make the proper behaviour the default again, but
|
||||
until then we're conservative.
|
||||
"""
|
||||
context = self._create_ssl_context(**sslctx_kwargs)
|
||||
|
||||
context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs)
|
||||
|
||||
context.use_privatekey(key)
|
||||
context.use_certificate(cert.x509)
|
||||
@ -542,12 +560,13 @@ class BaseHandler(_Connection):
|
||||
|
||||
return context
|
||||
|
||||
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
|
||||
def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
|
||||
"""
|
||||
Convert connection to SSL.
|
||||
For a list of parameters, see BaseHandler._create_ssl_context(...)
|
||||
"""
|
||||
context = self.create_ssl_context(cert, key, **sslctx_kwargs)
|
||||
|
||||
context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
try:
|
||||
|
@ -82,7 +82,8 @@ class TServer(tcp.TCPServer):
|
||||
request_client_cert=self.ssl["request_client_cert"],
|
||||
cipher_list=self.ssl.get("cipher_list", None),
|
||||
dhparams=self.ssl.get("dhparams", None),
|
||||
chain_file=self.ssl.get("chain_file", None)
|
||||
chain_file=self.ssl.get("chain_file", None),
|
||||
alpn_select=self.ssl.get("alpn_select", None)
|
||||
)
|
||||
h.handle()
|
||||
h.finish()
|
||||
|
@ -389,6 +389,24 @@ class TestTimeOut(test.ServerTestBase):
|
||||
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
|
||||
|
||||
|
||||
class TestALPN(test.ServerTestBase):
|
||||
handler = HangHandler
|
||||
ssl = dict(
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
alpn_select="h2"
|
||||
)
|
||||
|
||||
def test_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(alpn_protos=["h2"])
|
||||
print "ALPN: %s" % c.get_alpn_proto_negotiated()
|
||||
assert c.get_alpn_proto_negotiated() == "h2"
|
||||
|
||||
|
||||
class TestSSLTimeOut(test.ServerTestBase):
|
||||
handler = HangHandler
|
||||
ssl = dict(
|
||||
|
Loading…
Reference in New Issue
Block a user