add ALPN support to TCP abstraction

This commit is contained in:
Thomas Kriechbaumer 2015-05-28 17:46:44 +02:00
parent d50b9be0d5
commit 780836b182
3 changed files with 47 additions and 9 deletions

View File

@ -360,7 +360,9 @@ class _Connection(object):
def _create_ssl_context(self,
method=SSLv23_METHOD,
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
cipher_list=None
cipher_list=None,
alpn_protos=None,
alpn_select=None,
):
"""
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
@ -389,6 +391,17 @@ class _Connection(object):
if log_ssl_key:
context.set_info_callback(log_ssl_key)
# advertise application layer protocols
if alpn_protos is not None:
context.set_alpn_protos(alpn_protos)
# select application layer protocol
if alpn_select is not None:
def alpn_select_f(conn, options):
return bytes(alpn_select)
context.set_alpn_select_callback(alpn_select_f)
return context
@ -413,8 +426,8 @@ class TCPClient(_Connection):
self.ssl_established = False
self.sni = None
def create_ssl_context(self, cert=None, **sslctx_kwargs):
context = self._create_ssl_context(**sslctx_kwargs)
def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs):
context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
# Client Certs
if cert:
try:
@ -424,13 +437,13 @@ class TCPClient(_Connection):
raise NetLibError("SSL client certificate error: %s" % str(v))
return context
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
"""
cert: Path to a file containing both client cert and private key.
options: A bit field consisting of OpenSSL.SSL.OP_* values
"""
context = self.create_ssl_context(**sslctx_kwargs)
context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
if sni:
self.sni = sni
@ -465,6 +478,9 @@ class TCPClient(_Connection):
def gettimeout(self):
return self.connection.gettimeout()
def get_alpn_proto_negotiated(self):
return self.connection.get_alpn_proto_negotiated()
class BaseHandler(_Connection):
@ -492,6 +508,7 @@ class BaseHandler(_Connection):
request_client_cert=None,
chain_file=None,
dhparams=None,
alpn_select=None,
**sslctx_kwargs):
"""
cert: A certutils.SSLCert object.
@ -517,7 +534,8 @@ class BaseHandler(_Connection):
we may be able to make the proper behaviour the default again, but
until then we're conservative.
"""
context = self._create_ssl_context(**sslctx_kwargs)
context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs)
context.use_privatekey(key)
context.use_certificate(cert.x509)
@ -542,12 +560,13 @@ class BaseHandler(_Connection):
return context
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
"""
Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...)
"""
context = self.create_ssl_context(cert, key, **sslctx_kwargs)
context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state()
try:

View File

@ -82,7 +82,8 @@ class TServer(tcp.TCPServer):
request_client_cert=self.ssl["request_client_cert"],
cipher_list=self.ssl.get("cipher_list", None),
dhparams=self.ssl.get("dhparams", None),
chain_file=self.ssl.get("chain_file", None)
chain_file=self.ssl.get("chain_file", None),
alpn_select=self.ssl.get("alpn_select", None)
)
h.handle()
h.finish()

View File

@ -389,6 +389,24 @@ class TestTimeOut(test.ServerTestBase):
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
class TestALPN(test.ServerTestBase):
handler = HangHandler
ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
alpn_select="h2"
)
def test_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(alpn_protos=["h2"])
print "ALPN: %s" % c.get_alpn_proto_negotiated()
assert c.get_alpn_proto_negotiated() == "h2"
class TestSSLTimeOut(test.ServerTestBase):
handler = HangHandler
ssl = dict(