mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
close() methods for clients and servers.
This commit is contained in:
parent
1227369db3
commit
63d789109a
@ -66,7 +66,7 @@ class FileLike:
|
||||
if v:
|
||||
try:
|
||||
return self.o.sendall(v)
|
||||
except SSL.SysCallError:
|
||||
except SSL.Error:
|
||||
raise NetLibDisconnect()
|
||||
|
||||
def readline(self, size = None):
|
||||
@ -125,6 +125,20 @@ class TCPClient:
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
||||
self.connection = connection
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Does a hard close of the socket, i.e. a shutdown, followed by a close.
|
||||
"""
|
||||
try:
|
||||
if self.ssl_established:
|
||||
self.connection.shutdown()
|
||||
else:
|
||||
self.connection.shutdown(socket.SHUT_RDWR)
|
||||
self.connection.close()
|
||||
except (socket.error, SSL.Error):
|
||||
# Socket probably already closed
|
||||
pass
|
||||
|
||||
|
||||
class BaseHandler:
|
||||
"""
|
||||
@ -170,7 +184,7 @@ class BaseHandler:
|
||||
self.wfile.flush()
|
||||
self.wfile.close()
|
||||
self.rfile.close()
|
||||
self.connection.close()
|
||||
self.close()
|
||||
except socket.error:
|
||||
# Remote has disconnected
|
||||
pass
|
||||
@ -195,6 +209,20 @@ class BaseHandler:
|
||||
def handle(self): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Does a hard close of the socket, i.e. a shutdown, followed by a close.
|
||||
"""
|
||||
try:
|
||||
if self.ssl_established:
|
||||
self.connection.shutdown()
|
||||
else:
|
||||
self.connection.shutdown(socket.SHUT_RDWR)
|
||||
self.connection.close()
|
||||
except (socket.error, SSL.Error):
|
||||
# Socket probably already closed
|
||||
pass
|
||||
|
||||
|
||||
class TCPServer:
|
||||
request_queue_size = 20
|
||||
|
@ -54,7 +54,7 @@ class EchoHandler(tcp.BaseHandler):
|
||||
|
||||
class DisconnectHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.finish()
|
||||
self.close()
|
||||
|
||||
|
||||
class TServer(tcp.TCPServer):
|
||||
@ -102,6 +102,20 @@ class TestServer(ServerTestBase):
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class TestDisconnect(ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c.connect()
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class TestServerSSL(ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
@ -154,6 +168,24 @@ class TestSSLDisconnect(ServerTestBase):
|
||||
c.convert_to_ssl()
|
||||
# Excercise SSL.ZeroReturnError
|
||||
c.rfile.read(10)
|
||||
c.close()
|
||||
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
|
||||
tutils.raises(Queue.Empty, self.q.get_nowait)
|
||||
|
||||
|
||||
class TestDisconnect(ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c.connect()
|
||||
# Excercise SSL.ZeroReturnError
|
||||
c.rfile.read(10)
|
||||
c.wfile.write("foo")
|
||||
c.close()
|
||||
c.close()
|
||||
|
||||
|
||||
class TestTCPClient:
|
||||
|
Loading…
Reference in New Issue
Block a user