Use PyOpenSSL's underlying ffi interface to get current cipher for connections.

This commit is contained in:
Aldo Cortesi 2014-03-02 21:37:28 +13:00
parent 1acaf1c880
commit cfaa3da25c
2 changed files with 47 additions and 4 deletions

View File

@ -2,6 +2,7 @@ import select, socket, threading, sys, time, traceback
from OpenSSL import SSL from OpenSSL import SSL
import certutils import certutils
EINTR = 4 EINTR = 4
SSLv2_METHOD = SSL.SSLv2_METHOD SSLv2_METHOD = SSL.SSLv2_METHOD
@ -214,7 +215,16 @@ class Address(object):
return (self.address, self.family) == (other.address, other.family) return (self.address, self.family) == (other.address, other.family)
class SocketCloseMixin(object): class _Connection(object):
def get_current_cipher(self):
if not self.ssl_established:
return None
c = SSL._lib.SSL_get_current_cipher(self.connection._ssl)
name = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_name(c)))
bits = SSL._lib.SSL_CIPHER_get_bits(c, SSL._ffi.NULL)
version = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_version(c)))
return name, bits, version
def finish(self): def finish(self):
self.finished = True self.finished = True
try: try:
@ -248,7 +258,7 @@ class SocketCloseMixin(object):
pass pass
class TCPClient(SocketCloseMixin): class TCPClient(_Connection):
rbufsize = -1 rbufsize = -1
wbufsize = -1 wbufsize = -1
def __init__(self, address, source_address=None): def __init__(self, address, source_address=None):
@ -310,7 +320,7 @@ class TCPClient(SocketCloseMixin):
return self.connection.gettimeout() return self.connection.gettimeout()
class BaseHandler(SocketCloseMixin): class BaseHandler(_Connection):
""" """
The instantiator is expected to call the handle() and finish() methods. The instantiator is expected to call the handle() and finish() methods.

View File

@ -49,6 +49,13 @@ class ClientCipherListHandler(tcp.BaseHandler):
self.wfile.flush() self.wfile.flush()
class CurrentCipherHandler(tcp.BaseHandler):
sni = None
def handle(self):
self.wfile.write("%s"%str(self.get_current_cipher()))
self.wfile.flush()
class DisconnectHandler(tcp.BaseHandler): class DisconnectHandler(tcp.BaseHandler):
def handle(self): def handle(self):
self.close() self.close()
@ -151,7 +158,8 @@ class TestServerSSL(test.ServerTestBase):
cert = tutils.test_data.path("data/server.crt"), cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"), key = tutils.test_data.path("data/server.key"),
request_client_cert = False, request_client_cert = False,
v3_only = False v3_only = False,
cipher_list = "AES256-SHA"
) )
def test_echo(self): def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
@ -165,6 +173,15 @@ class TestServerSSL(test.ServerTestBase):
def test_get_remote_cert(self): def test_get_remote_cert(self):
assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1")
def test_get_current_cipher(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
assert not c.get_current_cipher()
c.convert_to_ssl(sni="foo.com")
ret = c.get_current_cipher()
assert ret
assert "AES" in ret[0]
class TestSSLv3Only(test.ServerTestBase): class TestSSLv3Only(test.ServerTestBase):
handler = EchoHandler handler = EchoHandler
@ -236,6 +253,22 @@ class TestServerCipherList(test.ServerTestBase):
assert c.rfile.readline() == "['RC4-SHA']" assert c.rfile.readline() == "['RC4-SHA']"
class TestServerCurrentCipher(test.ServerTestBase):
handler = CurrentCipherHandler
ssl = dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
request_client_cert = False,
v3_only = False,
cipher_list = 'RC4-SHA'
)
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(sni="foo.com")
assert "RC4-SHA" in c.rfile.readline()
class TestServerCipherListError(test.ServerTestBase): class TestServerCipherListError(test.ServerTestBase):
handler = ClientCipherListHandler handler = ClientCipherListHandler
ssl = dict( ssl = dict(