This commit is contained in:
Maximilian Hils 2014-10-22 17:54:20 +02:00
parent 29a4e91050
commit ed5e685565

View File

@ -204,6 +204,37 @@ class Address(object):
return not self.__eq__(other)
def close_socket(sock):
"""
Does a hard close of a socket, without emitting a RST.
"""
try:
# We already indicate that we close our end.
# If we close RD, any further received bytes would result in a RST being set, which we want to avoid
# for our purposes
sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any
# pending readable data could lead to an immediate RST being sent (which is the case on Windows).
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
#
# However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above:
# Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that
# recv() would block infinitely.
# As a workaround, we set a timeout here even if we are in blocking mode.
# Please let us know if you have a better solution to this problem.
sock.settimeout(sock.gettimeout() or 20)
# may raise a timeout/disconnect exception.
while sock.recv(4096): # pragma: no cover
pass
except socket.error:
pass
sock.close()
class _Connection(object):
def get_current_cipher(self):
if not self.ssl_established:
@ -216,59 +247,36 @@ class _Connection(object):
def finish(self):
self.finished = True
try:
# If we have an SSL connection, wfile.close == connection.close
# (We call _FileLike.set_descriptor(conn))
# Closing the socket is not our task, therefore we don't call close then.
if type(self.connection) != SSL.Connection:
if not getattr(self.wfile, "closed", False):
self.wfile.flush()
self.close()
self.wfile.close()
self.rfile.close()
except (socket.error, NetLibDisconnect):
# Remote has disconnected
pass
def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a
close.
"""
try:
if type(self.connection) == SSL.Connection:
else:
try:
self.connection.shutdown()
self.connection.sock_shutdown(socket.SHUT_WR)
else:
self.connection.shutdown(socket.SHUT_WR)
if type(self.connection) != SSL.Connection or self.ssl_established:
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any
# pending readable data could lead to an immediate RST being sent (which is the case on Windows).
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
#
# However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above:
# Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that
# recv() would block infinitely.
# As a workaround, we set a timeout here even if we were in blocking mode.
# Please let us know if you have a better solution to this problem.
#
# Do not call this for every SSL.Connection:
# If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection
# again at this point, calls the SNI handler and segfaults.
# https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499
timeout = self.connection.gettimeout()
self.connection.settimeout(timeout or 60)
while self.connection.recv(4096): # pragma: no cover
pass
self.connection.settimeout(timeout)
self.connection.close()
except (socket.error, SSL.Error, IOError):
# Socket probably already closed
pass
except SSL.Error:
pass
class TCPClient(_Connection):
rbufsize = -1
wbufsize = -1
def close(self):
# Make sure to close the real socket, not the SSL proxy.
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
# it tries to renegotiate...
if type(self.connection) == SSL.Connection:
close_socket(self.connection._socket)
else:
close_socket(self.connection)
def __init__(self, address, source_address=None):
self.address = Address.wrap(address)
self.source_address = Address.wrap(source_address) if source_address else None
@ -430,7 +438,6 @@ class BaseHandler(_Connection):
self.connection.settimeout(n)
class TCPServer(object):
request_queue_size = 20
def __init__(self, address):
@ -450,11 +457,7 @@ class TCPServer(object):
except:
self.handle_error(connection, client_address)
finally:
try:
connection.shutdown(socket.SHUT_RDWR)
except:
pass
connection.close()
close_socket(connection)
def serve_forever(self, poll_interval=0.1):
self.__is_shut_down.clear()