From 9883509f894dde57c8a71340a69581ac46c44f51 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 12:44:29 +0200 Subject: [PATCH] simplify default ssl params for test servers --- netlib/test.py | 30 +++++++++++++------ test/test_tcp.py | 76 +++++++----------------------------------------- 2 files changed, 32 insertions(+), 74 deletions(-) diff --git a/netlib/test.py b/netlib/test.py index 14f501577..ee8c66852 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -4,6 +4,7 @@ import Queue import cStringIO import OpenSSL from . import tcp, certutils +import tutils class ServerThread(threading.Thread): @@ -55,22 +56,33 @@ class TServer(tcp.TCPServer): dhparams, v3_only """ tcp.TCPServer.__init__(self, addr) - self.ssl, self.q = ssl, q + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q self.handler_klass = handler_klass self.last_handler = None def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h - if self.ssl: - cert = certutils.SSLCert.from_pem( - file(self.ssl["cert"], "rb").read() - ) - raw = file(self.ssl["key"], "rb").read() + if self.ssl is not None: + raw_cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + cert = certutils.SSLCert.from_pem(file(raw_cert, "rb").read()) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, - raw) - if self.ssl["v3_only"]: + file(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): method = tcp.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 else: @@ -81,7 +93,7 @@ class TServer(tcp.TCPServer): method=method, options=options, handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl["request_client_cert"], + request_client_cert=self.ssl.get("request_client_cert", None), cipher_list=self.ssl.get("cipher_list", None), dhparams=self.ssl.get("dhparams", None), chain_file=self.ssl.get("chain_file", None), diff --git a/test/test_tcp.py b/test/test_tcp.py index 14ba555d5..cbe92f3c8 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -135,10 +135,6 @@ class TestFinishFail(test.ServerTestBase): class TestServerSSL(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list="AES256-SHA", chain_file=tutils.test_data.path("data/server.crt") ) @@ -165,8 +161,6 @@ class TestServerSSL(test.ServerTestBase): class TestSSLv3Only(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), request_client_cert=False, v3_only=True ) @@ -188,9 +182,8 @@ class TestSSLClientCert(test.ServerTestBase): def handle(self): self.wfile.write("%s\n" % self.clientcert.serial) self.wfile.flush() + ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), request_client_cert=True, v3_only=False ) @@ -224,12 +217,7 @@ class TestSNI(test.ServerTestBase): self.wfile.write(self.sni) self.wfile.flush() - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + ssl = True def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -242,10 +230,6 @@ class TestSNI(test.ServerTestBase): class TestServerCipherList(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='RC4-SHA' ) @@ -264,11 +248,8 @@ class TestServerCurrentCipher(test.ServerTestBase): def handle(self): self.wfile.write("%s" % str(self.get_current_cipher())) self.wfile.flush() + ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='RC4-SHA' ) @@ -282,10 +263,6 @@ class TestServerCurrentCipher(test.ServerTestBase): class TestServerCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='bogus' ) @@ -298,10 +275,6 @@ class TestServerCipherListError(test.ServerTestBase): class TestClientCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='RC4-SHA' ) @@ -321,12 +294,8 @@ class TestSSLDisconnect(test.ServerTestBase): def handle(self): self.finish() - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + + ssl = True def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -341,12 +310,7 @@ class TestSSLDisconnect(test.ServerTestBase): class TestSSLHardDisconnect(test.ServerTestBase): handler = HardDisconnectHandler - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + ssl = True def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -400,13 +364,9 @@ class TestTimeOut(test.ServerTestBase): class TestALPN(test.ServerTestBase): - handler = HangHandler + handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, - alpn_select="h2" + alpn_select="foobar" ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -414,19 +374,13 @@ class TestALPN(test.ServerTestBase): def test_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["h2"]) - print "ALPN: %s" % c.get_alpn_proto_negotiated() - assert c.get_alpn_proto_negotiated() == "h2" + c.convert_to_ssl(alpn_protos=["foobar"]) + assert c.get_alpn_proto_negotiated() == "foobar" class TestSSLTimeOut(test.ServerTestBase): handler = HangHandler - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + ssl = True def test_timeout_client(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -439,10 +393,6 @@ class TestSSLTimeOut(test.ServerTestBase): class TestDHParams(test.ServerTestBase): handler = HangHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, dhparams=certutils.CertStore.load_dhparam( tutils.test_data.path("data/dhparam.pem"), ), @@ -643,10 +593,6 @@ class TestAddress: class TestSSLKeyLogger(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list="AES256-SHA" )