diff --git a/.gitignore b/.gitignore index e66d51fe6..26c449d1c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ MANIFEST *.swp *.swo .coverage -.idea \ No newline at end of file +.idea +__pycache__ diff --git a/netlib/certffi.py b/netlib/certffi.py new file mode 100644 index 000000000..c5d7c95e3 --- /dev/null +++ b/netlib/certffi.py @@ -0,0 +1,36 @@ +import cffi +import OpenSSL +xffi = cffi.FFI() +xffi.cdef (""" + struct rsa_meth_st { + int flags; + ...; + }; + struct rsa_st { + int pad; + long version; + struct rsa_meth_st *meth; + ...; + }; +""") +xffi.verify( + """#include """, + extra_compile_args=['-w'] +) + +def handle(privkey): + new = xffi.new("struct rsa_st*") + newbuf = xffi.buffer(new) + rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey) + oldbuf = OpenSSL.SSL._ffi.buffer(rsa) + newbuf[:] = oldbuf[:] + return new + +def set_flags(privkey, val): + hdl = handle(privkey) + hdl.meth.flags = val + return privkey + +def get_flags(privkey): + hdl = handle(privkey) + return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index 19148382e..92b219ee6 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -111,6 +111,7 @@ class DNTree: return current.value + class CertStore: """ Implements an in-memory certificate store. @@ -222,6 +223,11 @@ class CertStore: c = (c, None) return (c[0], c[1] or self.privkey) + def gen_pkey(self, cert): + import certffi + certffi.set_flags(self.privkey, 1) + return self.privkey + class _GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore @@ -326,6 +332,7 @@ class SSLCert: return altnames + def get_remote_cert(host, port, sni): c = tcp.TCPClient((host, port)) c.connect() diff --git a/test/test_certutils.py b/test/test_certutils.py index 7f320e7ef..176575ea6 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -1,5 +1,5 @@ import os -from netlib import certutils +from netlib import certutils, certffi import OpenSSL import tutils @@ -83,6 +83,16 @@ class TestCertStore: ret = ca1.get_cert("foo.com", []) assert ret[0].serial == dc[0].serial + def test_gen_pkey(self): + try: + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") + ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") + cert = ca1.get_cert("foo.com", []) + assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 + finally: + certffi.set_flags(ca2.privkey, 0) + class TestDummyCert: def test_with_ca(self): @@ -125,3 +135,5 @@ class TestSSLCert: d = file(tutils.test_data.path("data/dercert"),"rb").read() s = certutils.SSLCert.from_der(d) assert s.cn + + diff --git a/test/test_tcp.py b/test/test_tcp.py index 814754cd0..ec995702e 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -4,16 +4,6 @@ import mock import tutils from OpenSSL import SSL -class SNIHandler(tcp.BaseHandler): - sni = None - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write(self.sni) - self.wfile.flush() - - class EchoHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -25,58 +15,19 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() -class ClientPeernameHandler(tcp.BaseHandler): - def handle(self): - self.wfile.write(str(self.connection.getpeername())) - self.wfile.flush() - - -class CertHandler(tcp.BaseHandler): - sni = None - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write("%s\n"%self.clientcert.serial) - self.wfile.flush() - - class ClientCipherListHandler(tcp.BaseHandler): sni = None - def handle(self): self.wfile.write("%s"%self.connection.get_cipher_list()) self.wfile.flush() -class CurrentCipherHandler(tcp.BaseHandler): - sni = None - def handle(self): - self.wfile.write("%s"%str(self.get_current_cipher())) - self.wfile.flush() - - -class DisconnectHandler(tcp.BaseHandler): - def handle(self): - self.close() - - class HangHandler(tcp.BaseHandler): def handle(self): while 1: time.sleep(1) -class TimeoutHandler(tcp.BaseHandler): - def handle(self): - self.timeout = False - self.settimeout(0.01) - try: - self.rfile.read(10) - except tcp.NetLibTimeout: - self.timeout = True - - class TestServer(test.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -89,7 +40,10 @@ class TestServer(test.ServerTestBase): class TestServerBind(test.ServerTestBase): - handler = ClientPeernameHandler + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write(str(self.connection.getpeername())) + self.wfile.flush() def test_bind(self): """ Test to bind to a given random port. Try again if the random port turned out to be blocked. """ @@ -198,7 +152,14 @@ class TestSSLv3Only(test.ServerTestBase): class TestSSLClientCert(test.ServerTestBase): - handler = CertHandler + class handler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write("%s\n"%self.clientcert.serial) + self.wfile.flush() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -222,7 +183,15 @@ class TestSSLClientCert(test.ServerTestBase): class TestSNI(test.ServerTestBase): - handler = SNIHandler + class handler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write(self.sni) + self.wfile.flush() + ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -254,7 +223,11 @@ class TestServerCipherList(test.ServerTestBase): class TestServerCurrentCipher(test.ServerTestBase): - handler = CurrentCipherHandler + class handler(tcp.BaseHandler): + sni = None + def handle(self): + self.wfile.write("%s"%str(self.get_current_cipher())) + self.wfile.flush() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -300,7 +273,9 @@ class TestClientCipherListError(test.ServerTestBase): class TestSSLDisconnect(test.ServerTestBase): - handler = DisconnectHandler + class handler(tcp.BaseHandler): + def handle(self): + self.close() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -329,7 +304,15 @@ class TestDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): - handler = TimeoutHandler + class handler(tcp.BaseHandler): + def handle(self): + self.timeout = False + self.settimeout(0.01) + try: + self.rfile.read(10) + except tcp.NetLibTimeout: + self.timeout = True + def test_timeout(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -383,6 +366,40 @@ class TestDHParams(test.ServerTestBase): assert ret[0] == "DHE-RSA-AES256-SHA" + +class TestPrivkeyGen(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(d, "test2") + ca2 = certutils.CertStore.from_store(d, "test3") + cert, _ = ca1.get_cert("foo.com", []) + key = ca2.gen_pkey(cert) + self.convert_to_ssl(cert, key) + + def test_privkey(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("bad record mac", c.convert_to_ssl) + + +class TestPrivkeyGenNoFlags(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(d, "test2") + ca2 = certutils.CertStore.from_store(d, "test3") + cert, _ = ca1.get_cert("foo.com", []) + certffi.set_flags(ca2.privkey, 0) + self.convert_to_ssl(cert, ca2.privkey) + + def test_privkey(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("unexpected eof", c.convert_to_ssl) + + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0))