close() methods for clients and servers.

This commit is contained in:
Aldo Cortesi 2012-07-20 14:43:51 +12:00
parent 1227369db3
commit 63d789109a
2 changed files with 64 additions and 4 deletions

View File

@ -66,7 +66,7 @@ class FileLike:
if v: if v:
try: try:
return self.o.sendall(v) return self.o.sendall(v)
except SSL.SysCallError: except SSL.Error:
raise NetLibDisconnect() raise NetLibDisconnect()
def readline(self, size = None): def readline(self, size = None):
@ -125,6 +125,20 @@ class TCPClient:
raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
self.connection = connection self.connection = connection
def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a close.
"""
try:
if self.ssl_established:
self.connection.shutdown()
else:
self.connection.shutdown(socket.SHUT_RDWR)
self.connection.close()
except (socket.error, SSL.Error):
# Socket probably already closed
pass
class BaseHandler: class BaseHandler:
""" """
@ -170,7 +184,7 @@ class BaseHandler:
self.wfile.flush() self.wfile.flush()
self.wfile.close() self.wfile.close()
self.rfile.close() self.rfile.close()
self.connection.close() self.close()
except socket.error: except socket.error:
# Remote has disconnected # Remote has disconnected
pass pass
@ -195,6 +209,20 @@ class BaseHandler:
def handle(self): # pragma: no cover def handle(self): # pragma: no cover
raise NotImplementedError raise NotImplementedError
def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a close.
"""
try:
if self.ssl_established:
self.connection.shutdown()
else:
self.connection.shutdown(socket.SHUT_RDWR)
self.connection.close()
except (socket.error, SSL.Error):
# Socket probably already closed
pass
class TCPServer: class TCPServer:
request_queue_size = 20 request_queue_size = 20

View File

@ -54,7 +54,7 @@ class EchoHandler(tcp.BaseHandler):
class DisconnectHandler(tcp.BaseHandler): class DisconnectHandler(tcp.BaseHandler):
def handle(self): def handle(self):
self.finish() self.close()
class TServer(tcp.TCPServer): class TServer(tcp.TCPServer):
@ -102,6 +102,20 @@ class TestServer(ServerTestBase):
assert c.rfile.readline() == testval assert c.rfile.readline() == testval
class TestDisconnect(ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
def test_echo(self):
testval = "echo!\n"
c = tcp.TCPClient("127.0.0.1", self.port)
c.connect()
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
class TestServerSSL(ServerTestBase): class TestServerSSL(ServerTestBase):
@classmethod @classmethod
def makeserver(cls): def makeserver(cls):
@ -154,6 +168,24 @@ class TestSSLDisconnect(ServerTestBase):
c.convert_to_ssl() c.convert_to_ssl()
# Excercise SSL.ZeroReturnError # Excercise SSL.ZeroReturnError
c.rfile.read(10) c.rfile.read(10)
c.close()
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
tutils.raises(Queue.Empty, self.q.get_nowait)
class TestDisconnect(ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
c.connect()
# Excercise SSL.ZeroReturnError
c.rfile.read(10)
c.wfile.write("foo")
c.close()
c.close()
class TestTCPClient: class TestTCPClient: