fix various fd/socket leaks

This commit is contained in:
Thomas Kriechbaumer 2017-05-24 14:09:41 +02:00
parent 2faaa0b2a2
commit ae7e9efb5c
13 changed files with 45 additions and 29 deletions

View File

@ -339,7 +339,8 @@ class View(collections.Sequence):
""" """
Load flows into the view, without processing them with addons. Load flows into the view, without processing them with addons.
""" """
for i in io.FlowReader(open(path, "rb")).stream(): with open(path, "rb") as f:
for i in io.FlowReader(f).stream():
# Do this to get a new ID, so we can load the same file N times and # Do this to get a new ID, so we can load the same file N times and
# get new flows each time. It would be more efficient to just have a # get new flows each time. It would be more efficient to just have a
# .newid() method or something. # .newid() method or something.

View File

@ -676,6 +676,8 @@ class TCPClient(_Connection):
sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover
except Exception as e: except Exception as e:
# socket.IP_TRANSPARENT might not be available on every OS and Python version # socket.IP_TRANSPARENT might not be available on every OS and Python version
if sock is not None:
sock.close()
raise exceptions.TcpException( raise exceptions.TcpException(
"Failed to spoof the source address: " + str(e) "Failed to spoof the source address: " + str(e)
) )

View File

@ -91,3 +91,7 @@ class FileGenerator:
def __repr__(self): def __repr__(self):
return "<%s" % self.path return "<%s" % self.path
def close(self):
self.map.close()
self.fp.close()

View File

@ -10,7 +10,8 @@ from mitmproxy.test import taddons
def tdump(path, flows): def tdump(path, flows):
w = io.FlowWriter(open(path, "wb")) with open(path, "wb") as f:
w = io.FlowWriter(f)
for i in flows: for i in flows:
w.add(i) w.add(i)

View File

@ -26,7 +26,8 @@ def test_configure(tmpdir):
def rd(p): def rd(p):
x = io.FlowReader(open(p, "rb")) with open(p, "rb") as f:
x = io.FlowReader(f)
return list(x.stream()) return list(x.stream())

View File

@ -11,7 +11,8 @@ from mitmproxy import io
def tdump(path, flows): def tdump(path, flows):
w = io.FlowWriter(open(path, "wb")) with open(path, "wb") as f:
w = io.FlowWriter(f)
for i in flows: for i in flows:
w.add(i) w.add(i)

View File

@ -132,7 +132,8 @@ def test_filter():
def tdump(path, flows): def tdump(path, flows):
w = io.FlowWriter(open(path, "wb")) with open(path, "wb") as f:
w = io.FlowWriter(f)
for i in flows: for i in flows:
w.add(i) w.add(i)

View File

@ -17,7 +17,9 @@ def test_view_protobuf_request():
m.configure_mock(**attrs) m.configure_mock(**attrs)
n.return_value = m n.return_value = m
content_type, output = v(open(p, "rb").read()) with open(p, "rb") as f:
data = f.read()
content_type, output = v(data)
assert content_type == "Protobuf" assert content_type == "Protobuf"
assert output[0] == [('text', b'1: "3bbc333c-e61c-433b-819a-0b9a8cc103b8"')] assert output[0] == [('text', b'1: "3bbc333c-e61c-433b-819a-0b9a8cc103b8"')]

View File

@ -34,7 +34,7 @@ class ClientCipherListHandler(tcp.BaseHandler):
sni = None sni = None
def handle(self): def handle(self):
self.wfile.write("%s" % self.connection.get_cipher_list()) self.wfile.write(str(self.connection.get_cipher_list()).encode())
self.wfile.flush() self.wfile.flush()
@ -391,14 +391,15 @@ class TestSNI(tservers.ServerTestBase):
class TestServerCipherList(tservers.ServerTestBase): class TestServerCipherList(tservers.ServerTestBase):
handler = ClientCipherListHandler handler = ClientCipherListHandler
ssl = dict( ssl = dict(
cipher_list='AES256-GCM-SHA384' cipher_list=b'AES256-GCM-SHA384'
) )
def test_echo(self): def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl(sni="foo.com") c.convert_to_ssl(sni="foo.com")
assert c.rfile.readline() == b"['AES256-GCM-SHA384']" expected = b"['AES256-GCM-SHA384']"
assert c.rfile.read(len(expected) + 2) == expected
class TestServerCurrentCipher(tservers.ServerTestBase): class TestServerCurrentCipher(tservers.ServerTestBase):
@ -424,7 +425,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
class TestServerCipherListError(tservers.ServerTestBase): class TestServerCipherListError(tservers.ServerTestBase):
handler = ClientCipherListHandler handler = ClientCipherListHandler
ssl = dict( ssl = dict(
cipher_list='bogus' cipher_list=b'bogus'
) )
def test_echo(self): def test_echo(self):
@ -632,6 +633,7 @@ class TestTCPServer:
with s.handler_counter: with s.handler_counter:
with pytest.raises(exceptions.Timeout): with pytest.raises(exceptions.Timeout):
s.wait_for_silence() s.wait_for_silence()
s.shutdown()
class TestFileLike: class TestFileLike:

View File

@ -9,10 +9,11 @@ class TestLookup:
def test_simple(self): def test_simple(self):
if sys.platform == "freebsd10": if sys.platform == "freebsd10":
p = tutils.test_data.path("mitmproxy/data/pf02") p = tutils.test_data.path("mitmproxy/data/pf02")
d = open(p, "rb").read()
else: else:
p = tutils.test_data.path("mitmproxy/data/pf01") p = tutils.test_data.path("mitmproxy/data/pf01")
d = open(p, "rb").read() with open(p, "rb") as f:
d = f.read()
assert pf.lookup("192.168.1.111", 40000, d) == ("5.5.5.5", 80) assert pf.lookup("192.168.1.111", 40000, d) == ("5.5.5.5", 80)
with pytest.raises(Exception, match="Could not resolve original destination"): with pytest.raises(Exception, match="Could not resolve original destination"):
pf.lookup("192.168.1.112", 40000, d) pf.lookup("192.168.1.112", 40000, d)

View File

@ -202,12 +202,14 @@ class TestMisc:
e.parseString("m@1") e.parseString("m@1")
s = base.Settings(staticdir=str(tmpdir)) s = base.Settings(staticdir=str(tmpdir))
tmpdir.join("path").write_binary(b"a" * 20, ensure=True) with open(str(tmpdir.join("path")), 'wb') as f:
f.write(b"a" * 20)
v = e.parseString("m<path")[0] v = e.parseString("m<path")[0]
with pytest.raises(Exception, match="Invalid value length"): with pytest.raises(Exception, match="Invalid value length"):
v.values(s) v.values(s)
tmpdir.join("path2").write_binary(b"a" * 4, ensure=True) with open(str(tmpdir.join("path2")), 'wb') as f:
f.write(b"a" * 4)
v = e.parseString("m<path2")[0] v = e.parseString("m<path2")[0]
assert v.values(s) assert v.values(s)

View File

@ -23,9 +23,7 @@ def test_filegenerator(tmpdir):
assert len(g[1:10]) == 9 assert len(g[1:10]) == 9
assert len(g[10000:10001]) == 0 assert len(g[10000:10001]) == 0
assert repr(g) assert repr(g)
# remove all references to FileGenerator instance to close the file g.close()
# handle.
del g
def test_transform_generator(): def test_transform_generator():

View File

@ -202,7 +202,7 @@ class TestApplySettings(net_tservers.ServerTestBase):
def handle(self): def handle(self):
# check settings acknowledgement # check settings acknowledgement
assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec') assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec')
self.wfile.write("OK") self.wfile.write(b"OK")
self.wfile.flush() self.wfile.flush()
self.rfile.safe_read(9) # just to keep the connection alive a bit longer self.rfile.safe_read(9) # just to keep the connection alive a bit longer