diff --git a/mitmproxy/addons/view.py b/mitmproxy/addons/view.py index 13a17c56e..aa3e11edb 100644 --- a/mitmproxy/addons/view.py +++ b/mitmproxy/addons/view.py @@ -339,11 +339,12 @@ class View(collections.Sequence): """ Load flows into the view, without processing them with addons. """ - for i in io.FlowReader(open(path, "rb")).stream(): - # Do this to get a new ID, so we can load the same file N times and - # get new flows each time. It would be more efficient to just have a - # .newid() method or something. - self.add([i.copy()]) + with open(path, "rb") as f: + for i in io.FlowReader(f).stream(): + # Do this to get a new ID, so we can load the same file N times and + # get new flows each time. It would be more efficient to just have a + # .newid() method or something. + self.add([i.copy()]) @command.command("view.go") def go(self, dst: int) -> None: diff --git a/mitmproxy/net/socks.py b/mitmproxy/net/socks.py index 570a4afbb..fdfcfb804 100644 --- a/mitmproxy/net/socks.py +++ b/mitmproxy/net/socks.py @@ -82,12 +82,12 @@ class ClientGreeting: client_greeting = cls(ver, []) if fail_early: client_greeting.assert_socks5() - client_greeting.methods.fromstring(f.safe_read(nmethods)) + client_greeting.methods.frombytes(f.safe_read(nmethods)) return client_greeting def to_file(self, f): f.write(struct.pack("!BB", self.ver, len(self.methods))) - f.write(self.methods.tostring()) + f.write(self.methods.tobytes()) class ServerGreeting: diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py index 81568d248..cdac4cd58 100644 --- a/mitmproxy/net/tcp.py +++ b/mitmproxy/net/tcp.py @@ -569,7 +569,9 @@ class TCPClient(_Connection): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # it tries to renegotiate... - if isinstance(self.connection, SSL.Connection): + if not self.connection: + return + elif isinstance(self.connection, SSL.Connection): close_socket(self.connection._socket) else: close_socket(self.connection) @@ -674,6 +676,8 @@ class TCPClient(_Connection): sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover except Exception as e: # socket.IP_TRANSPARENT might not be available on every OS and Python version + if sock is not None: + sock.close() raise exceptions.TcpException( "Failed to spoof the source address: " + str(e) ) @@ -864,6 +868,8 @@ class TCPServer: self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) self.socket.bind(self.address) except socket.error: + if self.socket: + self.socket.close() self.socket = None if not self.socket: diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py index 25867871f..1aa918472 100644 --- a/mitmproxy/proxy/protocol/http_replay.py +++ b/mitmproxy/proxy/protocol/http_replay.py @@ -83,7 +83,11 @@ class RequestReplayThread(basethread.BaseThread): server.wfile.write(http1.assemble_request(r)) server.wfile.flush() + + if self.f.server_conn: + self.f.server_conn.close() self.f.server_conn = server + self.f.response = http.HTTPResponse.wrap( http1.read_response( server.rfile, diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 50a2b76b8..5171fbee6 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -48,6 +48,8 @@ class ProxyServer(tcp.TCPServer): if config.options.mode == "transparent": platform.init_transparent_mode() except Exception as e: + if self.socket: + self.socket.close() raise exceptions.ServerException( 'Error starting proxy server: ' + repr(e) ) from e diff --git a/mitmproxy/tools/main.py b/mitmproxy/tools/main.py index d8fac077a..84dab1fe4 100644 --- a/mitmproxy/tools/main.py +++ b/mitmproxy/tools/main.py @@ -1,6 +1,8 @@ from __future__ import print_function # this is here for the version check to work on Python 2. import sys +# This must be at the very top, before importing anything else that might break! +# Keep all other imports below with the 'noqa' magic comment. if sys.version_info < (3, 5): print("#" * 49, file=sys.stderr) print("# mitmproxy only supports Python 3.5 and above! #", file=sys.stderr) @@ -13,8 +15,7 @@ from mitmproxy.tools import cmdline # noqa from mitmproxy import exceptions # noqa from mitmproxy import options # noqa from mitmproxy import optmanager # noqa -from mitmproxy.proxy import config # noqa -from mitmproxy.proxy import server # noqa +from mitmproxy import proxy # noqa from mitmproxy.utils import version_check # noqa from mitmproxy.utils import debug # noqa @@ -49,15 +50,7 @@ def process_options(parser, opts, args): adict[n] = getattr(args, n) opts.merge(adict) - pconf = config.ProxyConfig(opts) - if opts.server: - try: - return server.ProxyServer(pconf) - except exceptions.ServerException as v: - print(str(v), file=sys.stderr) - sys.exit(1) - else: - return server.DummyServer(pconf) + return proxy.config.ProxyConfig(opts) def run(MasterKlass, args, extra=None): # pragma: no cover @@ -74,7 +67,16 @@ def run(MasterKlass, args, extra=None): # pragma: no cover master = None try: unknown = optmanager.load_paths(opts, args.conf) - server = process_options(parser, opts, args) + pconf = process_options(parser, opts, args) + if pconf.options.server: + try: + server = proxy.server.ProxyServer(pconf) + except exceptions.ServerException as v: + print(str(v), file=sys.stderr) + sys.exit(1) + else: + server = proxy.server.DummyServer(pconf) + master = MasterKlass(opts, server) master.addons.trigger("configure", opts.keys()) master.addons.trigger("tick") diff --git a/pathod/language/generators.py b/pathod/language/generators.py index d716804da..93db30148 100644 --- a/pathod/language/generators.py +++ b/pathod/language/generators.py @@ -91,3 +91,7 @@ class FileGenerator: def __repr__(self): return "<%s" % self.path + + def close(self): + self.map.close() + self.fp.close() diff --git a/setup.py b/setup.py index a03d74fbc..38bd7ee47 100644 --- a/setup.py +++ b/setup.py @@ -61,8 +61,9 @@ setup( # It is not considered best practice to use install_requires to pin dependencies to specific versions. install_requires=[ "blinker>=1.4, <1.5", - "click>=6.2, <7", + "brotlipy>=0.5.1, <0.7", "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! + "click>=6.2, <7", "construct>=2.8, <2.9", "cryptography>=1.4, <1.9", "cssutils>=1.0.1, <1.1", @@ -79,37 +80,29 @@ setup( "pyperclip>=1.5.22, <1.6", "requests>=2.9.1, <3", "ruamel.yaml>=0.13.2, <0.15", + "sortedcontainers>=1.5.4, <1.6", "tornado>=4.3, <4.6", "urwid>=1.3.1, <1.4", - "brotlipy>=0.5.1, <0.7", - "sortedcontainers>=1.5.4, <1.6", - # transitive from cryptography, we just blacklist here. - # https://github.com/pypa/setuptools/issues/861 - "setuptools>=11.3, !=29.0.0", ], extras_require={ ':sys_platform == "win32"': [ "pydivert>=2.0.3, <2.1", ], - ':sys_platform != "win32"': [ - ], 'dev': [ - "Flask>=0.10.1, <0.13", "flake8>=3.2.1, <3.4", + "Flask>=0.10.1, <0.13", "mypy>=0.501, <0.502", - "rstcheck>=2.2, <4.0", - "tox>=2.3, <3", - "pytest>=3, <3.1", "pytest-cov>=2.2.1, <3", + "pytest-faulthandler>=1.3.0, <2", "pytest-timeout>=1.0.0, <2", "pytest-xdist>=1.14, <2", - "pytest-faulthandler>=1.3.0, <2", - "sphinx>=1.3.5, <1.7", - "sphinx-autobuild>=0.5.2, <0.7", - "sphinxcontrib-documentedlist>=0.5.0, <0.7", + "pytest>=3.1, <4", + "rstcheck>=2.2, <4.0", "sphinx_rtd_theme>=0.1.9, <0.3", - ], - 'contentviews': [ + "sphinx-autobuild>=0.5.2, <0.7", + "sphinx>=1.3.5, <1.7", + "sphinxcontrib-documentedlist>=0.5.0, <0.7", + "tox>=2.3, <3", ], 'examples': [ "beautifulsoup4>=4.4.1, <4.7", diff --git a/test/conftest.py b/test/conftest.py index b4e1da932..bb9135488 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,8 +1,6 @@ import os import pytest import OpenSSL -import functools -from contextlib import contextmanager import mitmproxy.net.tcp @@ -32,21 +30,3 @@ skip_appveyor = pytest.mark.skipif( def disable_alpn(monkeypatch): monkeypatch.setattr(mitmproxy.net.tcp, 'HAS_ALPN', False) monkeypatch.setattr(OpenSSL.SSL._lib, 'Cryptography_HAS_ALPN', False) - - -################################################################################ -# TODO: remove this wrapper when pytest 3.1.0 is released -original_pytest_raises = pytest.raises - - -@contextmanager -@functools.wraps(original_pytest_raises) -def raises(exc, *args, **kwargs): - with original_pytest_raises(exc, *args, **kwargs) as exc_info: - yield - if 'match' in kwargs: - assert exc_info.match(kwargs['match']) - - -pytest.raises = raises -################################################################################ diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 7ffda317b..6089b2d59 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -10,9 +10,10 @@ from mitmproxy.test import taddons def tdump(path, flows): - w = io.FlowWriter(open(path, "wb")) - for i in flows: - w.add(i) + with open(path, "wb") as f: + w = io.FlowWriter(f) + for i in flows: + w.add(i) class MockThread(): diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index 85c2a3985..a4e425cd4 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -26,8 +26,9 @@ def test_configure(tmpdir): def rd(p): - x = io.FlowReader(open(p, "rb")) - return list(x.stream()) + with open(p, "rb") as f: + x = io.FlowReader(f) + return list(x.stream()) def test_tcp(tmpdir): diff --git a/test/mitmproxy/addons/test_serverplayback.py b/test/mitmproxy/addons/test_serverplayback.py index 3ceab3fab..7605a5d99 100644 --- a/test/mitmproxy/addons/test_serverplayback.py +++ b/test/mitmproxy/addons/test_serverplayback.py @@ -11,9 +11,10 @@ from mitmproxy import io def tdump(path, flows): - w = io.FlowWriter(open(path, "wb")) - for i in flows: - w.add(i) + with open(path, "wb") as f: + w = io.FlowWriter(f) + for i in flows: + w.add(i) def test_load_file(tmpdir): diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 6da136502..d5a3a4562 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -132,9 +132,10 @@ def test_filter(): def tdump(path, flows): - w = io.FlowWriter(open(path, "wb")) - for i in flows: - w.add(i) + with open(path, "wb") as f: + w = io.FlowWriter(f) + for i in flows: + w.add(i) def test_create(): diff --git a/test/mitmproxy/contentviews/test_protobuf.py b/test/mitmproxy/contentviews/test_protobuf.py index 31e382ecb..71e515769 100644 --- a/test/mitmproxy/contentviews/test_protobuf.py +++ b/test/mitmproxy/contentviews/test_protobuf.py @@ -17,7 +17,9 @@ def test_view_protobuf_request(): m.configure_mock(**attrs) n.return_value = m - content_type, output = v(open(p, "rb").read()) + with open(p, "rb") as f: + data = f.read() + content_type, output = v(data) assert content_type == "Protobuf" assert output[0] == [('text', b'1: "3bbc333c-e61c-433b-819a-0b9a8cc103b8"')] diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index 81d518885..234e8afb5 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -34,7 +34,7 @@ class ClientCipherListHandler(tcp.BaseHandler): sni = None def handle(self): - self.wfile.write("%s" % self.connection.get_cipher_list()) + self.wfile.write(str(self.connection.get_cipher_list()).encode()) self.wfile.flush() @@ -391,14 +391,15 @@ class TestSNI(tservers.ServerTestBase): class TestServerCipherList(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cipher_list='AES256-GCM-SHA384' + cipher_list=b'AES256-GCM-SHA384' ) def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl(sni="foo.com") - assert c.rfile.readline() == b"['AES256-GCM-SHA384']" + expected = b"['AES256-GCM-SHA384']" + assert c.rfile.read(len(expected) + 2) == expected class TestServerCurrentCipher(tservers.ServerTestBase): @@ -424,7 +425,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): class TestServerCipherListError(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cipher_list='bogus' + cipher_list=b'bogus' ) def test_echo(self): @@ -632,6 +633,7 @@ class TestTCPServer: with s.handler_counter: with pytest.raises(exceptions.Timeout): s.wait_for_silence() + s.shutdown() class TestFileLike: diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py index ebe6d3eb1..44701aa50 100644 --- a/test/mitmproxy/net/tservers.py +++ b/test/mitmproxy/net/tservers.py @@ -16,9 +16,6 @@ class _ServerThread(threading.Thread): def run(self): self.server.serve_forever() - def shutdown(self): - self.server.shutdown() - class _TServer(tcp.TCPServer): @@ -54,9 +51,9 @@ class _TServer(tcp.TCPServer): raw_key = self.ssl.get( "key", tutils.test_data.path("mitmproxy/net/data/server.key")) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - open(raw_key, "rb").read()) + with open(raw_key) as f: + raw_key = f.read() + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw_key) if self.ssl.get("v3_only", False): method = OpenSSL.SSL.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 @@ -64,7 +61,8 @@ class _TServer(tcp.TCPServer): method = OpenSSL.SSL.SSLv23_METHOD options = None h.convert_to_ssl( - cert, key, + cert, + key, method=method, options=options, handle_sni=getattr(h, "handle_sni", None), @@ -103,7 +101,7 @@ class ServerTestBase: @classmethod def teardown_class(cls): - cls.server.shutdown() + cls.server.server.shutdown() def teardown(self): self.server.server.wait_for_silence() diff --git a/test/mitmproxy/platform/test_pf.py b/test/mitmproxy/platform/test_pf.py index f644bcc5a..3292d3456 100644 --- a/test/mitmproxy/platform/test_pf.py +++ b/test/mitmproxy/platform/test_pf.py @@ -9,10 +9,11 @@ class TestLookup: def test_simple(self): if sys.platform == "freebsd10": p = tutils.test_data.path("mitmproxy/data/pf02") - d = open(p, "rb").read() else: p = tutils.test_data.path("mitmproxy/data/pf01") - d = open(p, "rb").read() + with open(p, "rb") as f: + d = f.read() + assert pf.lookup("192.168.1.111", 40000, d) == ("5.5.5.5", 80) with pytest.raises(Exception, match="Could not resolve original destination"): pf.lookup("192.168.1.112", 40000, d) diff --git a/test/mitmproxy/proxy/protocol/test_http1.py b/test/mitmproxy/proxy/protocol/test_http1.py index 07cd7dcc0..b642afb3d 100644 --- a/test/mitmproxy/proxy/protocol/test_http1.py +++ b/test/mitmproxy/proxy/protocol/test_http1.py @@ -65,6 +65,7 @@ class TestExpectHeader(tservers.HTTPProxyTest): assert resp.status_code == 200 client.finish() + client.close() class TestHeadContentLength(tservers.HTTPProxyTest): diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index b07257b3b..261f8415c 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -118,12 +118,16 @@ class _Http2TestBase: self.master.reset([]) self.server.server.handle_server_event = self.handle_server_event - def _setup_connection(self): - client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() + def teardown(self): + if self.client: + self.client.close() + + def setup_connection(self): + self.client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + self.client.connect() # send CONNECT request - client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request( + self.client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request( 'authority', b'CONNECT', b'', @@ -134,13 +138,13 @@ class _Http2TestBase: [(b'host', b'localhost:%d' % self.server.server.address[1])], b'', ))) - client.wfile.flush() + self.client.wfile.flush() # read CONNECT response - while client.rfile.readline() != b"\r\n": + while self.client.rfile.readline() != b"\r\n": pass - client.convert_to_ssl(alpn_protos=[b'h2']) + self.client.convert_to_ssl(alpn_protos=[b'h2']) config = h2.config.H2Configuration( client_side=True, @@ -148,10 +152,10 @@ class _Http2TestBase: validate_inbound_headers=False) h2_conn = h2.connection.H2Connection(config) h2_conn.initiate_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() - return client, h2_conn + return h2_conn def _send_request(self, wfile, @@ -205,8 +209,8 @@ class TestSimple(_Http2Test): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): - assert (b'client-foo', b'client-bar-1') in event.headers - assert (b'client-foo', b'client-bar-2') in event.headers + assert (b'self.client-foo', b'self.client-bar-1') in event.headers + assert (b'self.client-foo', b'self.client-bar-2') in event.headers elif isinstance(event, h2.events.StreamEnded): import warnings with warnings.catch_warnings(): @@ -233,32 +237,32 @@ class TestSimple(_Http2Test): def test_simple(self): response_body_buffer = b'' - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), - ('ClIeNt-FoO', 'client-bar-1'), - ('ClIeNt-FoO', 'client-bar-2'), + ('self.client-FoO', 'self.client-bar-1'), + ('self.client-FoO', 'self.client-bar-2'), ], body=b'request body') done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.DataReceived): @@ -267,8 +271,8 @@ class TestSimple(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response.status_code == 200 @@ -317,10 +321,10 @@ class TestRequestWithPriority(_Http2Test): def test_request_with_priority(self, http2_priority_enabled, priority, expected_priority): self.config.options.http2_priority = http2_priority_enabled - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -336,22 +340,22 @@ class TestRequestWithPriority(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 @@ -397,15 +401,15 @@ class TestPriority(_Http2Test): self.config.options.http2_priority = http2_priority_enabled self.__class__.priority_data = [] - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() if prioritize_before: h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2]) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -419,28 +423,28 @@ class TestPriority(_Http2Test): if not prioritize_before: h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2]) h2_conn.end_stream(1) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.priority_data == expected_priority @@ -460,10 +464,10 @@ class TestStreamResetFromServer(_Http2Test): return True def test_request_with_priority(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -476,22 +480,22 @@ class TestStreamResetFromServer(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamReset): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response is None @@ -510,10 +514,10 @@ class TestBodySizeLimit(_Http2Test): self.config.options.body_size_limit = "20" self.config.options._processed["body_size_limit"] = 20 - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -527,22 +531,22 @@ class TestBodySizeLimit(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamReset): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 0 @@ -609,9 +613,9 @@ class TestPushPromise(_Http2Test): return True def test_push_promise(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -625,15 +629,15 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False except: break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): @@ -649,8 +653,8 @@ class TestPushPromise(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert ended_streams == 3 assert pushed_streams == 2 @@ -665,9 +669,9 @@ class TestPushPromise(_Http2Test): assert len(pushed_flows) == 2 def test_push_promise_reset(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -681,14 +685,14 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: @@ -696,8 +700,8 @@ class TestPushPromise(_Http2Test): elif isinstance(event, h2.events.PushedStreamReceived): pushed_streams += 1 h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() elif isinstance(event, h2.events.ResponseReceived): responses += 1 if isinstance(event, h2.events.ConnectionTerminated): @@ -707,8 +711,8 @@ class TestPushPromise(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() bodies = [flow.response.content for flow in self.master.state.flows if flow.response] assert len(bodies) >= 1 @@ -728,9 +732,9 @@ class TestConnectionLost(_Http2Test): return False def test_connection_lost(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -741,7 +745,7 @@ class TestConnectionLost(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) @@ -749,8 +753,8 @@ class TestConnectionLost(_Http2Test): except: break try: - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() except: break @@ -782,12 +786,12 @@ class TestMaxConcurrentStreams(_Http2Test): return True def test_max_concurrent_streams(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() new_streams = [1, 3, 5, 7, 9, 11] for stream_id in new_streams: # this will exceed MAX_CONCURRENT_STREAMS on the server connection # and cause mitmproxy to throttle stream creation to the server - self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=stream_id, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -798,20 +802,20 @@ class TestMaxConcurrentStreams(_Http2Test): ended_streams = 0 while ended_streams != len(new_streams): try: - header, body = http2.read_raw_frame(client.rfile) + header, body = http2.read_raw_frame(self.client.rfile) events = h2_conn.receive_data(b''.join([header, body])) except: break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): ended_streams += 1 h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == len(new_streams) for flow in self.master.state.flows: @@ -831,9 +835,9 @@ class TestConnectionTerminated(_Http2Test): return True def test_connection_terminated(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, headers=[ + self._send_request(self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -844,7 +848,7 @@ class TestConnectionTerminated(_Http2Test): connection_terminated_event = None while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) for event in events: if isinstance(event, h2.events.ConnectionTerminated): diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 8dfc4f2b5..f78e173fc 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -79,9 +79,13 @@ class _WebSocketTestBase: self.master.reset([]) self.server.server.handle_websockets = self.handle_websockets - def _setup_connection(self): - client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() + def teardown(self): + if self.client: + self.client.close() + + def setup_connection(self): + self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) + self.client.connect() request = http.Request( "authority", @@ -92,14 +96,14 @@ class _WebSocketTestBase: "", "HTTP/1.1", content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() + self.client.wfile.write(http.http1.assemble_request(request)) + self.client.wfile.flush() - response = http.http1.read_response(client.rfile, request) + response = http.http1.read_response(self.client.rfile, request) if self.ssl: - client.convert_to_ssl() - assert client.ssl_established + self.client.convert_to_ssl() + assert self.client.ssl_established request = http.Request( "relative", @@ -116,14 +120,12 @@ class _WebSocketTestBase: sec_websocket_key="1234", ), content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() + self.client.wfile.write(http.http1.assemble_request(request)) + self.client.wfile.flush() - response = http.http1.read_response(client.rfile, request) + response = http.http1.read_response(self.client.rfile, request) assert websockets.check_handshake(response.headers) - return client - class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase): @@ -154,25 +156,25 @@ class TestSimple(_WebSocketTest): wfile.flush() def test_simple(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'self.client-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() assert len(self.master.state.flows) == 2 assert isinstance(self.master.state.flows[0], HTTPFlow) @@ -180,9 +182,9 @@ class TestSimple(_WebSocketTest): assert len(self.master.state.flows[1].messages) == 5 assert self.master.state.flows[1].messages[0].content == b'server-foobar' assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[1].content == b'client-foobar' + assert self.master.state.flows[1].messages[1].content == b'self.client-foobar' assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[2].content == b'client-foobar' + assert self.master.state.flows[1].messages[2].content == b'self.client-foobar' assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY @@ -203,19 +205,19 @@ class TestSimpleTLS(_WebSocketTest): wfile.flush() def test_simple_tls(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'self.client-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() class TestPing(_WebSocketTest): @@ -233,16 +235,16 @@ class TestPing(_WebSocketTest): wfile.flush() def test_ping(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.PING assert frame.payload == b'foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.TEXT assert frame.payload == b'pong-received' @@ -259,12 +261,12 @@ class TestPong(_WebSocketTest): wfile.flush() def test_pong(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' @@ -282,34 +284,34 @@ class TestClose(_WebSocketTest): websockets.Frame.from_file(rfile) def test_close(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) def test_close_payload_1(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) def test_close_payload_2(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) class TestInvalidFrame(_WebSocketTest): @@ -320,9 +322,9 @@ class TestInvalidFrame(_WebSocketTest): wfile.flush() def test_invalid_frame(self): - client = self._setup_connection() + self.setup_connection() # with pytest.raises(exceptions.TcpDisconnect): - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == 15 assert frame.payload == b'foobar' diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index e320885d1..99367bb65 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -1,5 +1,4 @@ import socket -import os import threading import ssl import OpenSSL @@ -140,16 +139,17 @@ class TestServerConnection: assert d.last_log() c.finish() + c.close() d.shutdown() def test_terminate_error(self): d = test.Daemon() c = connections.ServerConnection((d.IFACE, d.port)) c.connect() + c.close() c.connection = mock.Mock() c.connection.recv = mock.Mock(return_value=False) c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) - c.finish() d.shutdown() def test_sni(self): @@ -194,22 +194,25 @@ class TestClientConnectionTLS: s = socket.create_connection(address) s = ctx.wrap_socket(s, server_hostname=sni) s.send(b'foobar') - s.shutdown(socket.SHUT_RDWR) + s.close() threading.Thread(target=client_run).start() connection, client_address = sock.accept() c = connections.ClientConnection(connection, client_address, None) cert = tutils.test_data.path("mitmproxy/net/data/server.crt") + with open(tutils.test_data.path("mitmproxy/net/data/server.key")) as f: + raw_key = f.read() key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, - open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read()) + raw_key) c.convert_to_ssl(cert, key) assert c.connected() assert c.sni == sni assert c.tls_established assert c.rfile.read(6) == b'foobar' c.finish() + sock.close() class TestServerConnectionTLS(tservers.ServerTestBase): @@ -222,7 +225,7 @@ class TestServerConnectionTLS(tservers.ServerTestBase): @pytest.mark.parametrize("clientcert", [ None, tutils.test_data.path("mitmproxy/data/clientcert"), - os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"), + tutils.test_data.path("mitmproxy/data/clientcert/client.pem"), ]) def test_tls(self, clientcert): c = connections.ServerConnection(("127.0.0.1", self.port)) diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py index e1d0da006..299abab3c 100644 --- a/test/mitmproxy/test_proxy.py +++ b/test/mitmproxy/test_proxy.py @@ -32,8 +32,7 @@ class TestProcessProxyOptions: opts = options.Options() cmdline.common_options(parser, opts) args = parser.parse_args(args=args) - main.process_options(parser, opts, args) - pconf = config.ProxyConfig(opts) + pconf = main.process_options(parser, opts, args) return parser, pconf def assert_noerr(self, *args): diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index b8005529d..3a2050e10 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -133,7 +133,7 @@ class ProxyTestBase: @classmethod def teardown_class(cls): - # perf: we want to run tests in parallell + # perf: we want to run tests in parallel # should this ever cause an error, travis should catch it. # shutil.rmtree(cls.cadir) cls.proxy.shutdown() diff --git a/test/pathod/language/test_base.py b/test/pathod/language/test_base.py index ec460b079..910d298a6 100644 --- a/test/pathod/language/test_base.py +++ b/test/pathod/language/test_base.py @@ -202,12 +202,14 @@ class TestMisc: e.parseString("m@1") s = base.Settings(staticdir=str(tmpdir)) - tmpdir.join("path").write_binary(b"a" * 20, ensure=True) + with open(str(tmpdir.join("path")), 'wb') as f: + f.write(b"a" * 20) v = e.parseString("m