diff --git a/netlib/tcp.py b/netlib/tcp.py index 6ba58d868..b7f2b3bcc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -66,7 +66,7 @@ class FileLike: if v: try: return self.o.sendall(v) - except SSL.SysCallError: + except SSL.Error: raise NetLibDisconnect() def readline(self, size = None): @@ -125,6 +125,20 @@ class TCPClient: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + else: + self.connection.shutdown(socket.SHUT_RDWR) + self.connection.close() + except (socket.error, SSL.Error): + # Socket probably already closed + pass + class BaseHandler: """ @@ -170,7 +184,7 @@ class BaseHandler: self.wfile.flush() self.wfile.close() self.rfile.close() - self.connection.close() + self.close() except socket.error: # Remote has disconnected pass @@ -195,6 +209,20 @@ class BaseHandler: def handle(self): # pragma: no cover raise NotImplementedError + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + else: + self.connection.shutdown(socket.SHUT_RDWR) + self.connection.close() + except (socket.error, SSL.Error): + # Socket probably already closed + pass + class TCPServer: request_queue_size = 20 @@ -252,7 +280,7 @@ class TCPServer: Called when handle_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be - # none. + # none. if traceback: exc = traceback.format_exc() print >> fp, '-'*40 diff --git a/test/test_tcp.py b/test/test_tcp.py index 359890d52..cb27c63b2 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -54,7 +54,7 @@ class EchoHandler(tcp.BaseHandler): class DisconnectHandler(tcp.BaseHandler): def handle(self): - self.finish() + self.close() class TServer(tcp.TCPServer): @@ -102,6 +102,20 @@ class TestServer(ServerTestBase): assert c.rfile.readline() == testval +class TestDisconnect(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + + def test_echo(self): + testval = "echo!\n" + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + class TestServerSSL(ServerTestBase): @classmethod def makeserver(cls): @@ -154,6 +168,24 @@ class TestSSLDisconnect(ServerTestBase): c.convert_to_ssl() # Excercise SSL.ZeroReturnError c.rfile.read(10) + c.close() + tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") + tutils.raises(Queue.Empty, self.q.get_nowait) + + +class TestDisconnect(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler) + + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + # Excercise SSL.ZeroReturnError + c.rfile.read(10) + c.wfile.write("foo") + c.close() + c.close() class TestTCPClient: