Certificate flags

This commit is contained in:
Aldo Cortesi 2014-03-10 17:29:27 +13:00
parent 2a12aa3c47
commit f5cc63d653
5 changed files with 130 additions and 57 deletions

1
.gitignore vendored
View File

@ -8,3 +8,4 @@ MANIFEST
*.swo *.swo
.coverage .coverage
.idea .idea
__pycache__

36
netlib/certffi.py Normal file
View File

@ -0,0 +1,36 @@
import cffi
import OpenSSL
xffi = cffi.FFI()
xffi.cdef ("""
struct rsa_meth_st {
int flags;
...;
};
struct rsa_st {
int pad;
long version;
struct rsa_meth_st *meth;
...;
};
""")
xffi.verify(
"""#include <openssl/rsa.h>""",
extra_compile_args=['-w']
)
def handle(privkey):
new = xffi.new("struct rsa_st*")
newbuf = xffi.buffer(new)
rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey)
oldbuf = OpenSSL.SSL._ffi.buffer(rsa)
newbuf[:] = oldbuf[:]
return new
def set_flags(privkey, val):
hdl = handle(privkey)
hdl.meth.flags = val
return privkey
def get_flags(privkey):
hdl = handle(privkey)
return hdl.meth.flags

View File

@ -111,6 +111,7 @@ class DNTree:
return current.value return current.value
class CertStore: class CertStore:
""" """
Implements an in-memory certificate store. Implements an in-memory certificate store.
@ -222,6 +223,11 @@ class CertStore:
c = (c, None) c = (c, None)
return (c[0], c[1] or self.privkey) return (c[0], c[1] or self.privkey)
def gen_pkey(self, cert):
import certffi
certffi.set_flags(self.privkey, 1)
return self.privkey
class _GeneralName(univ.Choice): class _GeneralName(univ.Choice):
# We are only interested in dNSNames. We use a default handler to ignore # We are only interested in dNSNames. We use a default handler to ignore
@ -326,6 +332,7 @@ class SSLCert:
return altnames return altnames
def get_remote_cert(host, port, sni): def get_remote_cert(host, port, sni):
c = tcp.TCPClient((host, port)) c = tcp.TCPClient((host, port))
c.connect() c.connect()

View File

@ -1,5 +1,5 @@
import os import os
from netlib import certutils from netlib import certutils, certffi
import OpenSSL import OpenSSL
import tutils import tutils
@ -83,6 +83,16 @@ class TestCertStore:
ret = ca1.get_cert("foo.com", []) ret = ca1.get_cert("foo.com", [])
assert ret[0].serial == dc[0].serial assert ret[0].serial == dc[0].serial
def test_gen_pkey(self):
try:
with tutils.tmpdir() as d:
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
cert = ca1.get_cert("foo.com", [])
assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
finally:
certffi.set_flags(ca2.privkey, 0)
class TestDummyCert: class TestDummyCert:
def test_with_ca(self): def test_with_ca(self):
@ -125,3 +135,5 @@ class TestSSLCert:
d = file(tutils.test_data.path("data/dercert"),"rb").read() d = file(tutils.test_data.path("data/dercert"),"rb").read()
s = certutils.SSLCert.from_der(d) s = certutils.SSLCert.from_der(d)
assert s.cn assert s.cn

View File

@ -4,16 +4,6 @@ import mock
import tutils import tutils
from OpenSSL import SSL from OpenSSL import SSL
class SNIHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
self.sni = connection.get_servername()
def handle(self):
self.wfile.write(self.sni)
self.wfile.flush()
class EchoHandler(tcp.BaseHandler): class EchoHandler(tcp.BaseHandler):
sni = None sni = None
def handle_sni(self, connection): def handle_sni(self, connection):
@ -25,58 +15,19 @@ class EchoHandler(tcp.BaseHandler):
self.wfile.flush() self.wfile.flush()
class ClientPeernameHandler(tcp.BaseHandler):
def handle(self):
self.wfile.write(str(self.connection.getpeername()))
self.wfile.flush()
class CertHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
self.sni = connection.get_servername()
def handle(self):
self.wfile.write("%s\n"%self.clientcert.serial)
self.wfile.flush()
class ClientCipherListHandler(tcp.BaseHandler): class ClientCipherListHandler(tcp.BaseHandler):
sni = None sni = None
def handle(self): def handle(self):
self.wfile.write("%s"%self.connection.get_cipher_list()) self.wfile.write("%s"%self.connection.get_cipher_list())
self.wfile.flush() self.wfile.flush()
class CurrentCipherHandler(tcp.BaseHandler):
sni = None
def handle(self):
self.wfile.write("%s"%str(self.get_current_cipher()))
self.wfile.flush()
class DisconnectHandler(tcp.BaseHandler):
def handle(self):
self.close()
class HangHandler(tcp.BaseHandler): class HangHandler(tcp.BaseHandler):
def handle(self): def handle(self):
while 1: while 1:
time.sleep(1) time.sleep(1)
class TimeoutHandler(tcp.BaseHandler):
def handle(self):
self.timeout = False
self.settimeout(0.01)
try:
self.rfile.read(10)
except tcp.NetLibTimeout:
self.timeout = True
class TestServer(test.ServerTestBase): class TestServer(test.ServerTestBase):
handler = EchoHandler handler = EchoHandler
def test_echo(self): def test_echo(self):
@ -89,7 +40,10 @@ class TestServer(test.ServerTestBase):
class TestServerBind(test.ServerTestBase): class TestServerBind(test.ServerTestBase):
handler = ClientPeernameHandler class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(str(self.connection.getpeername()))
self.wfile.flush()
def test_bind(self): def test_bind(self):
""" Test to bind to a given random port. Try again if the random port turned out to be blocked. """ """ Test to bind to a given random port. Try again if the random port turned out to be blocked. """
@ -198,7 +152,14 @@ class TestSSLv3Only(test.ServerTestBase):
class TestSSLClientCert(test.ServerTestBase): class TestSSLClientCert(test.ServerTestBase):
handler = CertHandler class handler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
self.sni = connection.get_servername()
def handle(self):
self.wfile.write("%s\n"%self.clientcert.serial)
self.wfile.flush()
ssl = dict( ssl = dict(
cert = tutils.test_data.path("data/server.crt"), cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"), key = tutils.test_data.path("data/server.key"),
@ -222,7 +183,15 @@ class TestSSLClientCert(test.ServerTestBase):
class TestSNI(test.ServerTestBase): class TestSNI(test.ServerTestBase):
handler = SNIHandler class handler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
self.sni = connection.get_servername()
def handle(self):
self.wfile.write(self.sni)
self.wfile.flush()
ssl = dict( ssl = dict(
cert = tutils.test_data.path("data/server.crt"), cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"), key = tutils.test_data.path("data/server.key"),
@ -254,7 +223,11 @@ class TestServerCipherList(test.ServerTestBase):
class TestServerCurrentCipher(test.ServerTestBase): class TestServerCurrentCipher(test.ServerTestBase):
handler = CurrentCipherHandler class handler(tcp.BaseHandler):
sni = None
def handle(self):
self.wfile.write("%s"%str(self.get_current_cipher()))
self.wfile.flush()
ssl = dict( ssl = dict(
cert = tutils.test_data.path("data/server.crt"), cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"), key = tutils.test_data.path("data/server.key"),
@ -300,7 +273,9 @@ class TestClientCipherListError(test.ServerTestBase):
class TestSSLDisconnect(test.ServerTestBase): class TestSSLDisconnect(test.ServerTestBase):
handler = DisconnectHandler class handler(tcp.BaseHandler):
def handle(self):
self.close()
ssl = dict( ssl = dict(
cert = tutils.test_data.path("data/server.crt"), cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"), key = tutils.test_data.path("data/server.key"),
@ -329,7 +304,15 @@ class TestDisconnect(test.ServerTestBase):
class TestServerTimeOut(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase):
handler = TimeoutHandler class handler(tcp.BaseHandler):
def handle(self):
self.timeout = False
self.settimeout(0.01)
try:
self.rfile.read(10)
except tcp.NetLibTimeout:
self.timeout = True
def test_timeout(self): def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() c.connect()
@ -383,6 +366,40 @@ class TestDHParams(test.ServerTestBase):
assert ret[0] == "DHE-RSA-AES256-SHA" assert ret[0] == "DHE-RSA-AES256-SHA"
class TestPrivkeyGen(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
with tutils.tmpdir() as d:
ca1 = certutils.CertStore.from_store(d, "test2")
ca2 = certutils.CertStore.from_store(d, "test3")
cert, _ = ca1.get_cert("foo.com", [])
key = ca2.gen_pkey(cert)
self.convert_to_ssl(cert, key)
def test_privkey(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
tutils.raises("bad record mac", c.convert_to_ssl)
class TestPrivkeyGenNoFlags(test.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
with tutils.tmpdir() as d:
ca1 = certutils.CertStore.from_store(d, "test2")
ca2 = certutils.CertStore.from_store(d, "test3")
cert, _ = ca1.get_cert("foo.com", [])
certffi.set_flags(ca2.privkey, 0)
self.convert_to_ssl(cert, ca2.privkey)
def test_privkey(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
tutils.raises("unexpected eof", c.convert_to_ssl)
class TestTCPClient: class TestTCPClient:
def test_conerr(self): def test_conerr(self):
c = tcp.TCPClient(("127.0.0.1", 0)) c = tcp.TCPClient(("127.0.0.1", 0))