From 763cb90b66b23cd94b6e37df3d4c7b8e7f89492a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 17:26:35 +0100 Subject: [PATCH] add tcp.Address to unify ipv4/ipv6 address handling --- netlib/certutils.py | 2 +- netlib/tcp.py | 56 +++++++++++++++++++++++++++++++++------------ netlib/test.py | 11 ++++----- test/test_http.py | 2 +- test/test_tcp.py | 36 ++++++++++++++--------------- 5 files changed, 67 insertions(+), 40 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 0349bec7a..94294f6ec 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -237,7 +237,7 @@ class SSLCert: def get_remote_cert(host, port, sni): - c = tcp.TCPClient(host, port) + c = tcp.TCPClient((host, port)) c.connect() c.convert_to_ssl(sni=sni) return c.cert diff --git a/netlib/tcp.py b/netlib/tcp.py index e48f4f6b9..bad166d0b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,6 +173,35 @@ class Reader(_FileLike): return result +class Address(tuple): + """ + This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + """ + def __new__(cls, address, use_ipv6=False): + a = super(Address, cls).__new__(cls, tuple(address)) + a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET + return a + + @classmethod + def wrap(cls, t): + if isinstance(t, cls): + return t + else: + return cls(t) + + @property + def host(self): + return self[0] + + @property + def port(self): + return self[1] + + @property + def is_ipv6(self): + return self.family == socket.AF_INET6 + + class SocketCloseMixin: def finish(self): self.finished = True @@ -209,10 +238,9 @@ class SocketCloseMixin: class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 - def __init__(self, host, port, source_address=None, use_ipv6=False): - self.host, self.port = host, port + def __init__(self, address, source_address=None): + self.address = Address.wrap(address) self.source_address = source_address - self.use_ipv6 = use_ipv6 self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False @@ -245,14 +273,14 @@ class TCPClient(SocketCloseMixin): def connect(self): try: - connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: connection.bind(self.source_address) - connection.connect((self.host, self.port)) + connection.connect(self.address) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err)) self.connection = connection def settimeout(self, n): @@ -269,8 +297,9 @@ class BaseHandler(SocketCloseMixin): """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection): + def __init__(self, connection, address): self.connection = connection + self.address = Address.wrap(address) self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) @@ -339,19 +368,18 @@ class BaseHandler(SocketCloseMixin): class TCPServer: request_queue_size = 20 - def __init__(self, server_address, use_ipv6=False): - self.server_address = server_address - self.use_ipv6 = use_ipv6 + def __init__(self, address): + self.address = Address.wrap(address) self.__is_shut_down = threading.Event() self.__shutdown_request = False - self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - self.port = self.server_address[1] + self.socket.bind(self.address) + self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) def connection_thread(self, connection, client_address): + client_address = Address(client_address) try: self.handle_client_connection(connection, client_address) except: diff --git a/netlib/test.py b/netlib/test.py index f5599082a..565b97cd4 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -17,19 +17,18 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - use_ipv6 = False @classmethod def setupAll(cls): cls.q = Queue.Queue() s = cls.makeserver() - cls.port = s.port + cls.port = s.address.port cls.server = ServerThread(s) cls.server.start() @classmethod def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6) + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) @classmethod def teardownAll(cls): @@ -41,17 +40,17 @@ class ServerTestBase: class TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr, use_ipv6): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A {cert, key, v3_only} dict. """ - tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6) + tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request) + h = self.handler_klass(request, client_address) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( diff --git a/test/test_http.py b/test/test_http.py index a03861151..e80e4b8f4 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -223,7 +223,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase): handler = NoContentLengthHTTPHandler def test_no_content_length(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) assert content == "bar\r\n\r\n" diff --git a/test/test_tcp.py b/test/test_tcp.py index 7f2c21c4e..49e206356 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -73,7 +73,7 @@ class TestServer(test.ServerTestBase): handler = EchoHandler def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -88,7 +88,7 @@ class TestServerBind(test.ServerTestBase): for i in range(20): random_port = random.randrange(1024, 65535) try: - c = tcp.TCPClient("127.0.0.1", self.port, source_address=("127.0.0.1", random_port)) + c = tcp.TCPClient(("127.0.0.1", self.port), source_address=("127.0.0.1", random_port)) c.connect() assert c.rfile.readline() == str(("127.0.0.1", random_port)) return @@ -98,11 +98,11 @@ class TestServerBind(test.ServerTestBase): class TestServerIPv6(test.ServerTestBase): handler = EchoHandler - use_ipv6 = True + addr = tcp.Address(("localhost", 0), use_ipv6=True) def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("::1", self.port, use_ipv6=True) + c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -127,7 +127,7 @@ class TestFinishFail(test.ServerTestBase): handler = FinishFailHandler def test_disconnect_in_finish(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write("foo\n") c.wfile.flush() @@ -137,7 +137,7 @@ class TestDisconnect(test.ServerTestBase): handler = EchoHandler def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -153,7 +153,7 @@ class TestServerSSL(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL) testval = "echo!\n" @@ -174,7 +174,7 @@ class TestSSLv3Only(test.ServerTestBase): v3_only = True ) def test_failure(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) @@ -188,13 +188,13 @@ class TestSSLClientCert(test.ServerTestBase): v3_only = False ) def test_clientcert(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem")) assert c.rfile.readline().strip() == "1" def test_clientcert_err(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises( tcp.NetLibError, @@ -212,7 +212,7 @@ class TestSNI(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == "foo.com" @@ -228,7 +228,7 @@ class TestClientCipherList(test.ServerTestBase): cipher_list = 'RC4-SHA' ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == "['RC4-SHA']" @@ -243,7 +243,7 @@ class TestSSLDisconnect(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() # Excercise SSL.ZeroReturnError @@ -255,7 +255,7 @@ class TestSSLDisconnect(test.ServerTestBase): class TestDisconnect(test.ServerTestBase): def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.rfile.read(10) c.wfile.write("foo") @@ -266,7 +266,7 @@ class TestDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): handler = TimeoutHandler def test_timeout(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() time.sleep(0.3) assert self.last_handler.timeout @@ -275,7 +275,7 @@ class TestServerTimeOut(test.ServerTestBase): class TestTimeOut(test.ServerTestBase): handler = HangHandler def test_timeout(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.settimeout(0.1) assert c.gettimeout() == 0.1 @@ -291,7 +291,7 @@ class TestSSLTimeOut(test.ServerTestBase): v3_only = False ) def test_timeout_client(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() c.settimeout(0.1) @@ -300,7 +300,7 @@ class TestSSLTimeOut(test.ServerTestBase): class TestTCPClient: def test_conerr(self): - c = tcp.TCPClient("127.0.0.1", 0) + c = tcp.TCPClient(("127.0.0.1", 0)) tutils.raises(tcp.NetLibError, c.connect)