add ALPN support to TCP abstraction

This commit is contained in:
Thomas Kriechbaumer 2015-05-28 17:46:44 +02:00
parent d50b9be0d5
commit 780836b182
3 changed files with 47 additions and 9 deletions

View File

@ -360,7 +360,9 @@ class _Connection(object):
def _create_ssl_context(self, def _create_ssl_context(self,
method=SSLv23_METHOD, method=SSLv23_METHOD,
options=(OP_NO_SSLv2 | OP_NO_SSLv3), options=(OP_NO_SSLv2 | OP_NO_SSLv3),
cipher_list=None cipher_list=None,
alpn_protos=None,
alpn_select=None,
): ):
""" """
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
@ -389,6 +391,17 @@ class _Connection(object):
if log_ssl_key: if log_ssl_key:
context.set_info_callback(log_ssl_key) context.set_info_callback(log_ssl_key)
# advertise application layer protocols
if alpn_protos is not None:
context.set_alpn_protos(alpn_protos)
# select application layer protocol
if alpn_select is not None:
def alpn_select_f(conn, options):
return bytes(alpn_select)
context.set_alpn_select_callback(alpn_select_f)
return context return context
@ -413,8 +426,8 @@ class TCPClient(_Connection):
self.ssl_established = False self.ssl_established = False
self.sni = None self.sni = None
def create_ssl_context(self, cert=None, **sslctx_kwargs): def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs):
context = self._create_ssl_context(**sslctx_kwargs) context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
# Client Certs # Client Certs
if cert: if cert:
try: try:
@ -424,13 +437,13 @@ class TCPClient(_Connection):
raise NetLibError("SSL client certificate error: %s" % str(v)) raise NetLibError("SSL client certificate error: %s" % str(v))
return context return context
def convert_to_ssl(self, sni=None, **sslctx_kwargs): def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
""" """
cert: Path to a file containing both client cert and private key. cert: Path to a file containing both client cert and private key.
options: A bit field consisting of OpenSSL.SSL.OP_* values options: A bit field consisting of OpenSSL.SSL.OP_* values
""" """
context = self.create_ssl_context(**sslctx_kwargs) context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection) self.connection = SSL.Connection(context, self.connection)
if sni: if sni:
self.sni = sni self.sni = sni
@ -465,6 +478,9 @@ class TCPClient(_Connection):
def gettimeout(self): def gettimeout(self):
return self.connection.gettimeout() return self.connection.gettimeout()
def get_alpn_proto_negotiated(self):
return self.connection.get_alpn_proto_negotiated()
class BaseHandler(_Connection): class BaseHandler(_Connection):
@ -492,6 +508,7 @@ class BaseHandler(_Connection):
request_client_cert=None, request_client_cert=None,
chain_file=None, chain_file=None,
dhparams=None, dhparams=None,
alpn_select=None,
**sslctx_kwargs): **sslctx_kwargs):
""" """
cert: A certutils.SSLCert object. cert: A certutils.SSLCert object.
@ -517,7 +534,8 @@ class BaseHandler(_Connection):
we may be able to make the proper behaviour the default again, but we may be able to make the proper behaviour the default again, but
until then we're conservative. until then we're conservative.
""" """
context = self._create_ssl_context(**sslctx_kwargs)
context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs)
context.use_privatekey(key) context.use_privatekey(key)
context.use_certificate(cert.x509) context.use_certificate(cert.x509)
@ -542,12 +560,13 @@ class BaseHandler(_Connection):
return context return context
def convert_to_ssl(self, cert, key, **sslctx_kwargs): def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
""" """
Convert connection to SSL. Convert connection to SSL.
For a list of parameters, see BaseHandler._create_ssl_context(...) For a list of parameters, see BaseHandler._create_ssl_context(...)
""" """
context = self.create_ssl_context(cert, key, **sslctx_kwargs)
context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection) self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state() self.connection.set_accept_state()
try: try:

View File

@ -82,7 +82,8 @@ class TServer(tcp.TCPServer):
request_client_cert=self.ssl["request_client_cert"], request_client_cert=self.ssl["request_client_cert"],
cipher_list=self.ssl.get("cipher_list", None), cipher_list=self.ssl.get("cipher_list", None),
dhparams=self.ssl.get("dhparams", None), dhparams=self.ssl.get("dhparams", None),
chain_file=self.ssl.get("chain_file", None) chain_file=self.ssl.get("chain_file", None),
alpn_select=self.ssl.get("alpn_select", None)
) )
h.handle() h.handle()
h.finish() h.finish()

View File

@ -389,6 +389,24 @@ class TestTimeOut(test.ServerTestBase):
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
class TestALPN(test.ServerTestBase):
handler = HangHandler
ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
alpn_select="h2"
)
def test_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(alpn_protos=["h2"])
print "ALPN: %s" % c.get_alpn_proto_negotiated()
assert c.get_alpn_proto_negotiated() == "h2"
class TestSSLTimeOut(test.ServerTestBase): class TestSSLTimeOut(test.ServerTestBase):
handler = HangHandler handler = HangHandler
ssl = dict( ssl = dict(