diff --git a/netlib/certutils.py b/netlib/certutils.py index d544cfa6e..19148382e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -115,10 +115,22 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, privkey, cacert): + def __init__(self, privkey, cacert, dhparams=None): self.privkey, self.cacert = privkey, cacert + self.dhparams = dhparams self.certs = DNTree() + @classmethod + def load_dhparam(klass, path): + bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r") + if bio != OpenSSL.SSL._ffi.NULL: + bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) + dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( + bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL + ) + dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) + return dh + @classmethod def from_store(klass, path, basename): p = os.path.join(path, basename + "-ca.pem") @@ -129,7 +141,9 @@ class CertStore: raw = file(p, "rb").read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - return klass(key, ca) + dhp = os.path.join(path, basename + "-dhparam.pem") + dh = klass.load_dhparam(dhp) + return klass(key, ca, dh) @classmethod def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): @@ -147,17 +161,17 @@ class CertStore: f.close() # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-cert.pem"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-cert.cer"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-cert.p12"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") p12 = OpenSSL.crypto.PKCS12() p12.set_certificate(ca) p12.set_privatekey(key) diff --git a/netlib/tcp.py b/netlib/tcp.py index 83059bc22..078ac4970 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -339,7 +339,10 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): + def convert_to_ssl(self, cert, key, + method=SSLv23_METHOD, options=None, handle_sni=None, + request_client_cert=False, cipher_list=None, dhparams=None + ): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -377,6 +380,8 @@ class BaseHandler(_Connection): ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey(key) ctx.use_certificate(cert.x509) + if dhparams: + SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams) if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) diff --git a/netlib/test.py b/netlib/test.py index b88b35865..bb0012ada 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -18,7 +18,6 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -43,15 +42,16 @@ class ServerTestBase: class TServer(tcp.TCPServer): def __init__(self, ssl, q, handler_klass, addr): """ - ssl: A {cert, key, v3_only} dict. + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only """ tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None - - def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h @@ -73,7 +73,8 @@ class TServer(tcp.TCPServer): options = options, handle_sni = getattr(h, "handle_sni", None), request_client_cert = self.ssl["request_client_cert"], - cipher_list = self.ssl.get("cipher_list", None) + cipher_list = self.ssl.get("cipher_list", None), + dhparams = self.ssl.get("dhparams", None) ) h.handle() h.finish() diff --git a/test/data/dhparam.pem b/test/data/dhparam.pem new file mode 100644 index 000000000..6f2526e10 --- /dev/null +++ b/test/data/dhparam.pem @@ -0,0 +1,5 @@ +-----BEGIN DH PARAMETERS----- +MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 +zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK +1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC +-----END DH PARAMETERS----- diff --git a/test/test_tcp.py b/test/test_tcp.py index d5d112940..814754cd0 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -363,6 +363,26 @@ class TestSSLTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) +class TestDHParams(test.ServerTestBase): + handler = HangHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + dhparams = certutils.CertStore.load_dhparam( + tutils.test_data.path("data/dhparam.pem"), + ), + cipher_list = "DHE-RSA-AES256-SHA" + ) + def test_dhparams(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + ret = c.get_current_cipher() + assert ret[0] == "DHE-RSA-AES256-SHA" + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0))