mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
close() methods for clients and servers.
This commit is contained in:
parent
1227369db3
commit
63d789109a
@ -66,7 +66,7 @@ class FileLike:
|
|||||||
if v:
|
if v:
|
||||||
try:
|
try:
|
||||||
return self.o.sendall(v)
|
return self.o.sendall(v)
|
||||||
except SSL.SysCallError:
|
except SSL.Error:
|
||||||
raise NetLibDisconnect()
|
raise NetLibDisconnect()
|
||||||
|
|
||||||
def readline(self, size = None):
|
def readline(self, size = None):
|
||||||
@ -125,6 +125,20 @@ class TCPClient:
|
|||||||
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Does a hard close of the socket, i.e. a shutdown, followed by a close.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.ssl_established:
|
||||||
|
self.connection.shutdown()
|
||||||
|
else:
|
||||||
|
self.connection.shutdown(socket.SHUT_RDWR)
|
||||||
|
self.connection.close()
|
||||||
|
except (socket.error, SSL.Error):
|
||||||
|
# Socket probably already closed
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseHandler:
|
class BaseHandler:
|
||||||
"""
|
"""
|
||||||
@ -170,7 +184,7 @@ class BaseHandler:
|
|||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
self.wfile.close()
|
self.wfile.close()
|
||||||
self.rfile.close()
|
self.rfile.close()
|
||||||
self.connection.close()
|
self.close()
|
||||||
except socket.error:
|
except socket.error:
|
||||||
# Remote has disconnected
|
# Remote has disconnected
|
||||||
pass
|
pass
|
||||||
@ -195,6 +209,20 @@ class BaseHandler:
|
|||||||
def handle(self): # pragma: no cover
|
def handle(self): # pragma: no cover
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Does a hard close of the socket, i.e. a shutdown, followed by a close.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.ssl_established:
|
||||||
|
self.connection.shutdown()
|
||||||
|
else:
|
||||||
|
self.connection.shutdown(socket.SHUT_RDWR)
|
||||||
|
self.connection.close()
|
||||||
|
except (socket.error, SSL.Error):
|
||||||
|
# Socket probably already closed
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TCPServer:
|
class TCPServer:
|
||||||
request_queue_size = 20
|
request_queue_size = 20
|
||||||
@ -252,7 +280,7 @@ class TCPServer:
|
|||||||
Called when handle_connection raises an exception.
|
Called when handle_connection raises an exception.
|
||||||
"""
|
"""
|
||||||
# If a thread has persisted after interpreter exit, the module might be
|
# If a thread has persisted after interpreter exit, the module might be
|
||||||
# none.
|
# none.
|
||||||
if traceback:
|
if traceback:
|
||||||
exc = traceback.format_exc()
|
exc = traceback.format_exc()
|
||||||
print >> fp, '-'*40
|
print >> fp, '-'*40
|
||||||
|
@ -54,7 +54,7 @@ class EchoHandler(tcp.BaseHandler):
|
|||||||
|
|
||||||
class DisconnectHandler(tcp.BaseHandler):
|
class DisconnectHandler(tcp.BaseHandler):
|
||||||
def handle(self):
|
def handle(self):
|
||||||
self.finish()
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
class TServer(tcp.TCPServer):
|
class TServer(tcp.TCPServer):
|
||||||
@ -102,6 +102,20 @@ class TestServer(ServerTestBase):
|
|||||||
assert c.rfile.readline() == testval
|
assert c.rfile.readline() == testval
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisconnect(ServerTestBase):
|
||||||
|
@classmethod
|
||||||
|
def makeserver(cls):
|
||||||
|
return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
|
||||||
|
|
||||||
|
def test_echo(self):
|
||||||
|
testval = "echo!\n"
|
||||||
|
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||||
|
c.connect()
|
||||||
|
c.wfile.write(testval)
|
||||||
|
c.wfile.flush()
|
||||||
|
assert c.rfile.readline() == testval
|
||||||
|
|
||||||
|
|
||||||
class TestServerSSL(ServerTestBase):
|
class TestServerSSL(ServerTestBase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def makeserver(cls):
|
def makeserver(cls):
|
||||||
@ -154,6 +168,24 @@ class TestSSLDisconnect(ServerTestBase):
|
|||||||
c.convert_to_ssl()
|
c.convert_to_ssl()
|
||||||
# Excercise SSL.ZeroReturnError
|
# Excercise SSL.ZeroReturnError
|
||||||
c.rfile.read(10)
|
c.rfile.read(10)
|
||||||
|
c.close()
|
||||||
|
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
|
||||||
|
tutils.raises(Queue.Empty, self.q.get_nowait)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisconnect(ServerTestBase):
|
||||||
|
@classmethod
|
||||||
|
def makeserver(cls):
|
||||||
|
return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
|
||||||
|
|
||||||
|
def test_echo(self):
|
||||||
|
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||||
|
c.connect()
|
||||||
|
# Excercise SSL.ZeroReturnError
|
||||||
|
c.rfile.read(10)
|
||||||
|
c.wfile.write("foo")
|
||||||
|
c.close()
|
||||||
|
c.close()
|
||||||
|
|
||||||
|
|
||||||
class TestTCPClient:
|
class TestTCPClient:
|
||||||
|
Loading…
Reference in New Issue
Block a user