simplify default ssl params for test servers

This commit is contained in:
Thomas Kriechbaumer 2015-06-05 12:44:29 +02:00
parent 113c5c187f
commit 9883509f89
2 changed files with 32 additions and 74 deletions

View File

@ -4,6 +4,7 @@ import Queue
import cStringIO import cStringIO
import OpenSSL import OpenSSL
from . import tcp, certutils from . import tcp, certutils
import tutils
class ServerThread(threading.Thread): class ServerThread(threading.Thread):
@ -55,22 +56,33 @@ class TServer(tcp.TCPServer):
dhparams, v3_only dhparams, v3_only
""" """
tcp.TCPServer.__init__(self, addr) tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
if ssl is True:
self.ssl = dict()
elif isinstance(ssl, dict):
self.ssl = ssl
else:
self.ssl = None
self.q = q
self.handler_klass = handler_klass self.handler_klass = handler_klass
self.last_handler = None self.last_handler = None
def handle_client_connection(self, request, client_address): def handle_client_connection(self, request, client_address):
h = self.handler_klass(request, client_address, self) h = self.handler_klass(request, client_address, self)
self.last_handler = h self.last_handler = h
if self.ssl: if self.ssl is not None:
cert = certutils.SSLCert.from_pem( raw_cert = self.ssl.get(
file(self.ssl["cert"], "rb").read() "cert",
) tutils.test_data.path("data/server.crt"))
raw = file(self.ssl["key"], "rb").read() cert = certutils.SSLCert.from_pem(file(raw_cert, "rb").read())
raw_key = self.ssl.get(
"key",
tutils.test_data.path("data/server.key"))
key = OpenSSL.crypto.load_privatekey( key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_PEM,
raw) file(raw_key, "rb").read())
if self.ssl["v3_only"]: if self.ssl.get("v3_only", False):
method = tcp.SSLv3_METHOD method = tcp.SSLv3_METHOD
options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1
else: else:
@ -81,7 +93,7 @@ class TServer(tcp.TCPServer):
method=method, method=method,
options=options, options=options,
handle_sni=getattr(h, "handle_sni", None), handle_sni=getattr(h, "handle_sni", None),
request_client_cert=self.ssl["request_client_cert"], request_client_cert=self.ssl.get("request_client_cert", None),
cipher_list=self.ssl.get("cipher_list", None), cipher_list=self.ssl.get("cipher_list", None),
dhparams=self.ssl.get("dhparams", None), dhparams=self.ssl.get("dhparams", None),
chain_file=self.ssl.get("chain_file", None), chain_file=self.ssl.get("chain_file", None),

View File

@ -135,10 +135,6 @@ class TestFinishFail(test.ServerTestBase):
class TestServerSSL(test.ServerTestBase): class TestServerSSL(test.ServerTestBase):
handler = EchoHandler handler = EchoHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
cipher_list="AES256-SHA", cipher_list="AES256-SHA",
chain_file=tutils.test_data.path("data/server.crt") chain_file=tutils.test_data.path("data/server.crt")
) )
@ -165,8 +161,6 @@ class TestServerSSL(test.ServerTestBase):
class TestSSLv3Only(test.ServerTestBase): class TestSSLv3Only(test.ServerTestBase):
handler = EchoHandler handler = EchoHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False, request_client_cert=False,
v3_only=True v3_only=True
) )
@ -188,9 +182,8 @@ class TestSSLClientCert(test.ServerTestBase):
def handle(self): def handle(self):
self.wfile.write("%s\n" % self.clientcert.serial) self.wfile.write("%s\n" % self.clientcert.serial)
self.wfile.flush() self.wfile.flush()
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=True, request_client_cert=True,
v3_only=False v3_only=False
) )
@ -224,12 +217,7 @@ class TestSNI(test.ServerTestBase):
self.wfile.write(self.sni) self.wfile.write(self.sni)
self.wfile.flush() self.wfile.flush()
ssl = dict( ssl = True
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False
)
def test_echo(self): def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
@ -242,10 +230,6 @@ class TestSNI(test.ServerTestBase):
class TestServerCipherList(test.ServerTestBase): class TestServerCipherList(test.ServerTestBase):
handler = ClientCipherListHandler handler = ClientCipherListHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
cipher_list='RC4-SHA' cipher_list='RC4-SHA'
) )
@ -264,11 +248,8 @@ class TestServerCurrentCipher(test.ServerTestBase):
def handle(self): def handle(self):
self.wfile.write("%s" % str(self.get_current_cipher())) self.wfile.write("%s" % str(self.get_current_cipher()))
self.wfile.flush() self.wfile.flush()
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
cipher_list='RC4-SHA' cipher_list='RC4-SHA'
) )
@ -282,10 +263,6 @@ class TestServerCurrentCipher(test.ServerTestBase):
class TestServerCipherListError(test.ServerTestBase): class TestServerCipherListError(test.ServerTestBase):
handler = ClientCipherListHandler handler = ClientCipherListHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
cipher_list='bogus' cipher_list='bogus'
) )
@ -298,10 +275,6 @@ class TestServerCipherListError(test.ServerTestBase):
class TestClientCipherListError(test.ServerTestBase): class TestClientCipherListError(test.ServerTestBase):
handler = ClientCipherListHandler handler = ClientCipherListHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
cipher_list='RC4-SHA' cipher_list='RC4-SHA'
) )
@ -321,12 +294,8 @@ class TestSSLDisconnect(test.ServerTestBase):
def handle(self): def handle(self):
self.finish() self.finish()
ssl = dict(
cert=tutils.test_data.path("data/server.crt"), ssl = True
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False
)
def test_echo(self): def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
@ -341,12 +310,7 @@ class TestSSLDisconnect(test.ServerTestBase):
class TestSSLHardDisconnect(test.ServerTestBase): class TestSSLHardDisconnect(test.ServerTestBase):
handler = HardDisconnectHandler handler = HardDisconnectHandler
ssl = dict( ssl = True
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False
)
def test_echo(self): def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
@ -400,13 +364,9 @@ class TestTimeOut(test.ServerTestBase):
class TestALPN(test.ServerTestBase): class TestALPN(test.ServerTestBase):
handler = HangHandler handler = EchoHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"), alpn_select="foobar"
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
alpn_select="h2"
) )
if OpenSSL._util.lib.Cryptography_HAS_ALPN: if OpenSSL._util.lib.Cryptography_HAS_ALPN:
@ -414,19 +374,13 @@ class TestALPN(test.ServerTestBase):
def test_alpn(self): def test_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() c.connect()
c.convert_to_ssl(alpn_protos=["h2"]) c.convert_to_ssl(alpn_protos=["foobar"])
print "ALPN: %s" % c.get_alpn_proto_negotiated() assert c.get_alpn_proto_negotiated() == "foobar"
assert c.get_alpn_proto_negotiated() == "h2"
class TestSSLTimeOut(test.ServerTestBase): class TestSSLTimeOut(test.ServerTestBase):
handler = HangHandler handler = HangHandler
ssl = dict( ssl = True
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False
)
def test_timeout_client(self): def test_timeout_client(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
@ -439,10 +393,6 @@ class TestSSLTimeOut(test.ServerTestBase):
class TestDHParams(test.ServerTestBase): class TestDHParams(test.ServerTestBase):
handler = HangHandler handler = HangHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
dhparams=certutils.CertStore.load_dhparam( dhparams=certutils.CertStore.load_dhparam(
tutils.test_data.path("data/dhparam.pem"), tutils.test_data.path("data/dhparam.pem"),
), ),
@ -643,10 +593,6 @@ class TestAddress:
class TestSSLKeyLogger(test.ServerTestBase): class TestSSLKeyLogger(test.ServerTestBase):
handler = EchoHandler handler = EchoHandler
ssl = dict( ssl = dict(
cert=tutils.test_data.path("data/server.crt"),
key=tutils.test_data.path("data/server.key"),
request_client_cert=False,
v3_only=False,
cipher_list="AES256-SHA" cipher_list="AES256-SHA"
) )