diff --git a/netlib/tcp.py b/netlib/tcp.py index 8f2ebdf03..0dff807ba 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -2,6 +2,7 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils + EINTR = 4 SSLv2_METHOD = SSL.SSLv2_METHOD @@ -214,7 +215,16 @@ class Address(object): return (self.address, self.family) == (other.address, other.family) -class SocketCloseMixin(object): +class _Connection(object): + def get_current_cipher(self): + if not self.ssl_established: + return None + c = SSL._lib.SSL_get_current_cipher(self.connection._ssl) + name = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_name(c))) + bits = SSL._lib.SSL_CIPHER_get_bits(c, SSL._ffi.NULL) + version = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_version(c))) + return name, bits, version + def finish(self): self.finished = True try: @@ -248,7 +258,7 @@ class SocketCloseMixin(object): pass -class TCPClient(SocketCloseMixin): +class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 def __init__(self, address, source_address=None): @@ -310,7 +320,7 @@ class TCPClient(SocketCloseMixin): return self.connection.gettimeout() -class BaseHandler(SocketCloseMixin): +class BaseHandler(_Connection): """ The instantiator is expected to call the handle() and finish() methods. diff --git a/test/test_tcp.py b/test/test_tcp.py index 4e27a632f..387e3f33d 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -49,6 +49,13 @@ class ClientCipherListHandler(tcp.BaseHandler): self.wfile.flush() +class CurrentCipherHandler(tcp.BaseHandler): + sni = None + def handle(self): + self.wfile.write("%s"%str(self.get_current_cipher())) + self.wfile.flush() + + class DisconnectHandler(tcp.BaseHandler): def handle(self): self.close() @@ -151,7 +158,8 @@ class TestServerSSL(test.ServerTestBase): cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), request_client_cert = False, - v3_only = False + v3_only = False, + cipher_list = "AES256-SHA" ) def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -165,6 +173,15 @@ class TestServerSSL(test.ServerTestBase): def test_get_remote_cert(self): assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") + def test_get_current_cipher(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + assert not c.get_current_cipher() + c.convert_to_ssl(sni="foo.com") + ret = c.get_current_cipher() + assert ret + assert "AES" in ret[0] + class TestSSLv3Only(test.ServerTestBase): handler = EchoHandler @@ -236,6 +253,22 @@ class TestServerCipherList(test.ServerTestBase): assert c.rfile.readline() == "['RC4-SHA']" +class TestServerCurrentCipher(test.ServerTestBase): + handler = CurrentCipherHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + cipher_list = 'RC4-SHA' + ) + def test_echo(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(sni="foo.com") + assert "RC4-SHA" in c.rfile.readline() + + class TestServerCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict(