nuke tutils.tmpdir, use pytest tmpdir

This commit is contained in:
Thomas Kriechbaumer 2017-03-12 22:55:22 +01:00
parent d069ba9da5
commit 1b045d24bc
13 changed files with 252 additions and 324 deletions

View File

@ -1,9 +1,5 @@
from io import BytesIO
import tempfile
import os
import time import time
import shutil from io import BytesIO
from contextlib import contextmanager
from mitmproxy.utils import data from mitmproxy.utils import data
from mitmproxy.net import tcp from mitmproxy.net import tcp
@ -13,18 +9,6 @@ from mitmproxy.net import http
test_data = data.Data(__name__).push("../../test/") test_data = data.Data(__name__).push("../../test/")
@contextmanager
def tmpdir(*args, **kwargs):
orig_workdir = os.getcwd()
temp_workdir = tempfile.mkdtemp(*args, **kwargs)
os.chdir(temp_workdir)
yield temp_workdir
os.chdir(orig_workdir)
shutil.rmtree(temp_workdir)
def treader(bytes): def treader(bytes):
""" """
Construct a tcp.Read object from bytes. Construct a tcp.Read object from bytes.

View File

@ -1,9 +1,7 @@
import os
import pytest import pytest
from unittest import mock from unittest import mock
from mitmproxy.test import tflow from mitmproxy.test import tflow
from mitmproxy.test import tutils
from mitmproxy import io from mitmproxy import io
from mitmproxy import exceptions from mitmproxy import exceptions
@ -49,11 +47,10 @@ class TestClientPlayback:
cp.tick() cp.tick()
assert cp.current_thread is None assert cp.current_thread is None
def test_configure(self): def test_configure(self, tmpdir):
cp = clientplayback.ClientPlayback() cp = clientplayback.ClientPlayback()
with taddons.context() as tctx: with taddons.context() as tctx:
with tutils.tmpdir() as td: path = str(tmpdir.join("flows"))
path = os.path.join(td, "flows")
tdump(path, [tflow.tflow()]) tdump(path, [tflow.tflow()])
tctx.configure(cp, client_replay=[path]) tctx.configure(cp, client_replay=[path])
tctx.configure(cp, client_replay=[]) tctx.configure(cp, client_replay=[])

View File

@ -1,11 +1,9 @@
import os.path
import pytest import pytest
from mitmproxy.test import tflow
from mitmproxy.test import tutils
from .. import tservers from .. import tservers
from mitmproxy.addons import replace from mitmproxy.addons import replace
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy.test import tflow
class TestReplace: class TestReplace:
@ -71,18 +69,16 @@ class TestUpstreamProxy(tservers.HTTPUpstreamProxyTest):
class TestReplaceFile: class TestReplaceFile:
def test_simple(self): def test_simple(self, tmpdir):
r = replace.ReplaceFile() r = replace.ReplaceFile()
with tutils.tmpdir() as td: rp = tmpdir.join("replacement")
rp = os.path.join(td, "replacement") rp.write("bar")
with open(rp, "w") as f:
f.write("bar")
with taddons.context() as tctx: with taddons.context() as tctx:
tctx.configure( tctx.configure(
r, r,
replacement_files = [ replacement_files = [
"/~q/foo/" + rp, "/~q/foo/" + str(rp),
"/~s/foo/" + rp, "/~s/foo/" + str(rp),
"/~b nonexistent/nonexistent/nonexistent", "/~b nonexistent/nonexistent/nonexistent",
] ]
) )

View File

@ -68,13 +68,12 @@ class TestParseCommand:
with pytest.raises(ValueError): with pytest.raises(ValueError):
script.parse_command(" ") script.parse_command(" ")
def test_no_script_file(self): def test_no_script_file(self, tmpdir):
with pytest.raises(Exception, match="not found"): with pytest.raises(Exception, match="not found"):
script.parse_command("notfound") script.parse_command("notfound")
with tutils.tmpdir() as dir:
with pytest.raises(Exception, match="Not a file"): with pytest.raises(Exception, match="Not a file"):
script.parse_command(dir) script.parse_command(str(tmpdir))
def test_parse_args(self): def test_parse_args(self):
with utils.chdir(tutils.test_data.dirname): with utils.chdir(tutils.test_data.dirname):
@ -128,15 +127,13 @@ class TestScript:
recf = sc.ns.call_log[0] recf = sc.ns.call_log[0]
assert recf[1] == "request" assert recf[1] == "request"
def test_reload(self): def test_reload(self, tmpdir):
with taddons.context() as tctx: with taddons.context() as tctx:
with tutils.tmpdir(): f = tmpdir.join("foo.py")
with open("foo.py", "w"): f.ensure(file=True)
pass sc = script.Script(str(f))
sc = script.Script("foo.py")
tctx.configure(sc) tctx.configure(sc)
for _ in range(100): for _ in range(100):
with open("foo.py", "a") as f:
f.write(".") f.write(".")
sc.tick() sc.tick()
time.sleep(0.1) time.sleep(0.1)

View File

@ -1,10 +1,8 @@
import os
import urllib import urllib
import pytest import pytest
from mitmproxy.test import tutils
from mitmproxy.test import tflow
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy.test import tflow
import mitmproxy.test.tutils import mitmproxy.test.tutils
from mitmproxy.addons import serverplayback from mitmproxy.addons import serverplayback
@ -19,15 +17,14 @@ def tdump(path, flows):
w.add(i) w.add(i)
def test_config(): def test_config(tmpdir):
s = serverplayback.ServerPlayback() s = serverplayback.ServerPlayback()
with tutils.tmpdir() as p:
with taddons.context() as tctx: with taddons.context() as tctx:
fpath = os.path.join(p, "flows") fpath = str(tmpdir.join("flows"))
tdump(fpath, [tflow.tflow(resp=True)]) tdump(fpath, [tflow.tflow(resp=True)])
tctx.configure(s, server_replay=[fpath]) tctx.configure(s, server_replay=[fpath])
with pytest.raises(exceptions.OptionsError): with pytest.raises(exceptions.OptionsError):
tctx.configure(s, server_replay=[p]) tctx.configure(s, server_replay=[str(tmpdir)])
def test_tick(): def test_tick():

View File

@ -1,9 +1,7 @@
import os.path
import pytest import pytest
from mitmproxy.test import tflow
from mitmproxy.test import tutils
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy.test import tflow
from mitmproxy import io from mitmproxy import io
from mitmproxy import exceptions from mitmproxy import exceptions
@ -11,15 +9,13 @@ from mitmproxy import options
from mitmproxy.addons import streamfile from mitmproxy.addons import streamfile
def test_configure(): def test_configure(tmpdir):
sa = streamfile.StreamFile() sa = streamfile.StreamFile()
with taddons.context(options=options.Options()) as tctx: with taddons.context(options=options.Options()) as tctx:
with tutils.tmpdir() as tdir:
p = os.path.join(tdir, "foo")
with pytest.raises(exceptions.OptionsError): with pytest.raises(exceptions.OptionsError):
tctx.configure(sa, streamfile=tdir) tctx.configure(sa, streamfile=str(tmpdir))
with pytest.raises(Exception, match="Invalid filter"): with pytest.raises(Exception, match="Invalid filter"):
tctx.configure(sa, streamfile=p, filtstr="~~") tctx.configure(sa, streamfile=str(tmpdir.join("foo")), filtstr="~~")
tctx.configure(sa, filtstr="foo") tctx.configure(sa, filtstr="foo")
assert sa.filt assert sa.filt
tctx.configure(sa, filtstr=None) tctx.configure(sa, filtstr=None)
@ -31,11 +27,10 @@ def rd(p):
return list(x.stream()) return list(x.stream())
def test_tcp(): def test_tcp(tmpdir):
sa = streamfile.StreamFile() sa = streamfile.StreamFile()
with taddons.context() as tctx: with taddons.context() as tctx:
with tutils.tmpdir() as tdir: p = str(tmpdir.join("foo"))
p = os.path.join(tdir, "foo")
tctx.configure(sa, streamfile=p) tctx.configure(sa, streamfile=p)
tt = tflow.ttcpflow() tt = tflow.ttcpflow()
@ -45,11 +40,10 @@ def test_tcp():
assert rd(p) assert rd(p)
def test_simple(): def test_simple(tmpdir):
sa = streamfile.StreamFile() sa = streamfile.StreamFile()
with taddons.context() as tctx: with taddons.context() as tctx:
with tutils.tmpdir() as tdir: p = str(tmpdir.join("foo"))
p = os.path.join(tdir, "foo")
tctx.configure(sa, streamfile=p) tctx.configure(sa, streamfile=p)

View File

@ -11,8 +11,8 @@ from OpenSSL import SSL
from mitmproxy import certs from mitmproxy import certs
from mitmproxy.net import tcp from mitmproxy.net import tcp
from mitmproxy.test import tutils
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.test import tutils
from . import tservers from . import tservers
from ...conftest import requires_alpn from ...conftest import requires_alpn
@ -783,12 +783,11 @@ class TestSSLKeyLogger(tservers.ServerTestBase):
cipher_list="AES256-SHA" cipher_list="AES256-SHA"
) )
def test_log(self): def test_log(self, tmpdir):
testval = b"echo!\n" testval = b"echo!\n"
_logfun = tcp.log_ssl_key _logfun = tcp.log_ssl_key
with tutils.tmpdir() as d: logfile = str(tmpdir.join("foo", "bar", "logfile"))
logfile = os.path.join(d, "foo", "bar", "logfile")
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile) tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))

View File

@ -34,24 +34,21 @@ from mitmproxy.test import tutils
class TestCertStore: class TestCertStore:
def test_create_explicit(self): def test_create_explicit(self, tmpdir):
with tutils.tmpdir() as d: ca = certs.CertStore.from_store(str(tmpdir), "test")
ca = certs.CertStore.from_store(d, "test")
assert ca.get_cert(b"foo", []) assert ca.get_cert(b"foo", [])
ca2 = certs.CertStore.from_store(d, "test") ca2 = certs.CertStore.from_store(str(tmpdir), "test")
assert ca2.get_cert(b"foo", []) assert ca2.get_cert(b"foo", [])
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
def test_create_no_common_name(self): def test_create_no_common_name(self, tmpdir):
with tutils.tmpdir() as d: ca = certs.CertStore.from_store(str(tmpdir), "test")
ca = certs.CertStore.from_store(d, "test")
assert ca.get_cert(None, [])[0].cn is None assert ca.get_cert(None, [])[0].cn is None
def test_create_tmp(self): def test_create_tmp(self, tmpdir):
with tutils.tmpdir() as d: ca = certs.CertStore.from_store(str(tmpdir), "test")
ca = certs.CertStore.from_store(d, "test")
assert ca.get_cert(b"foo.com", []) assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"foo.com", []) assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"*.foo.com", []) assert ca.get_cert(b"*.foo.com", [])
@ -59,25 +56,22 @@ class TestCertStore:
r = ca.get_cert(b"*.foo.com", []) r = ca.get_cert(b"*.foo.com", [])
assert r[1] == ca.default_privatekey assert r[1] == ca.default_privatekey
def test_sans(self): def test_sans(self, tmpdir):
with tutils.tmpdir() as d: ca = certs.CertStore.from_store(str(tmpdir), "test")
ca = certs.CertStore.from_store(d, "test")
c1 = ca.get_cert(b"foo.com", [b"*.bar.com"]) c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
ca.get_cert(b"foo.bar.com", []) ca.get_cert(b"foo.bar.com", [])
# assert c1 == c2 # assert c1 == c2
c3 = ca.get_cert(b"bar.com", []) c3 = ca.get_cert(b"bar.com", [])
assert not c1 == c3 assert not c1 == c3
def test_sans_change(self): def test_sans_change(self, tmpdir):
with tutils.tmpdir() as d: ca = certs.CertStore.from_store(str(tmpdir), "test")
ca = certs.CertStore.from_store(d, "test")
ca.get_cert(b"foo.com", [b"*.bar.com"]) ca.get_cert(b"foo.com", [b"*.bar.com"])
cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"]) cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
assert b"*.baz.com" in cert.altnames assert b"*.baz.com" in cert.altnames
def test_expire(self): def test_expire(self, tmpdir):
with tutils.tmpdir() as d: ca = certs.CertStore.from_store(str(tmpdir), "test")
ca = certs.CertStore.from_store(d, "test")
ca.STORE_CAP = 3 ca.STORE_CAP = 3
ca.get_cert(b"one.com", []) ca.get_cert(b"one.com", [])
ca.get_cert(b"two.com", []) ca.get_cert(b"two.com", [])
@ -100,35 +94,29 @@ class TestCertStore:
assert (b"three.com", ()) in ca.certs assert (b"three.com", ()) in ca.certs
assert (b"four.com", ()) in ca.certs assert (b"four.com", ()) in ca.certs
def test_overrides(self): def test_overrides(self, tmpdir):
with tutils.tmpdir() as d: ca1 = certs.CertStore.from_store(str(tmpdir.join("ca1")), "test")
ca1 = certs.CertStore.from_store(os.path.join(d, "ca1"), "test") ca2 = certs.CertStore.from_store(str(tmpdir.join("ca2")), "test")
ca2 = certs.CertStore.from_store(os.path.join(d, "ca2"), "test") assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
assert not ca1.default_ca.get_serial_number(
) == ca2.default_ca.get_serial_number()
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = os.path.join(d, "dc") dcp = tmpdir.join("dc")
f = open(dcp, "wb") dcp.write(dc[0].to_pem())
f.write(dc[0].to_pem()) ca1.add_cert_file(b"foo.com", str(dcp))
f.close()
ca1.add_cert_file(b"foo.com", dcp)
ret = ca1.get_cert(b"foo.com", []) ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial assert ret[0].serial == dc[0].serial
def test_create_dhparams(self): def test_create_dhparams(self, tmpdir):
with tutils.tmpdir() as d: filename = str(tmpdir.join("dhparam.pem"))
filename = os.path.join(d, "dhparam.pem")
certs.CertStore.load_dhparam(filename) certs.CertStore.load_dhparam(filename)
assert os.path.exists(filename) assert os.path.exists(filename)
class TestDummyCert: class TestDummyCert:
def test_with_ca(self): def test_with_ca(self, tmpdir):
with tutils.tmpdir() as d: ca = certs.CertStore.from_store(str(tmpdir), "test")
ca = certs.CertStore.from_store(d, "test")
r = certs.dummy_cert( r = certs.dummy_cert(
ca.default_privatekey, ca.default_privatekey,
ca.default_ca, ca.default_ca,

View File

@ -1,5 +1,4 @@
import json import json
import os
import shlex import shlex
import pytest import pytest
@ -142,9 +141,8 @@ class TestHARDump:
with pytest.raises(ScriptError): with pytest.raises(ScriptError):
tscript("complex/har_dump.py") tscript("complex/har_dump.py")
def test_simple(self): def test_simple(self, tmpdir):
with tutils.tmpdir() as tdir: path = str(tmpdir.join("somefile"))
path = os.path.join(tdir, "somefile")
m, sc = tscript("complex/har_dump.py", shlex.quote(path)) m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", self.flow()) m.addons.invoke(m, "response", self.flow())
@ -152,12 +150,10 @@ class TestHARDump:
with open(path, "r") as inp: with open(path, "r") as inp:
har = json.load(inp) har = json.load(inp)
assert len(har["log"]["entries"]) == 1 assert len(har["log"]["entries"]) == 1
def test_base64(self): def test_base64(self, tmpdir):
with tutils.tmpdir() as tdir: path = str(tmpdir.join("somefile"))
path = os.path.join(tdir, "somefile")
m, sc = tscript("complex/har_dump.py", shlex.quote(path)) m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10)) m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10))
@ -165,7 +161,6 @@ class TestHARDump:
with open(path, "r") as inp: with open(path, "r") as inp:
har = json.load(inp) har = json.load(inp)
assert har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64" assert har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64"
def test_format_cookies(self): def test_format_cookies(self):
@ -187,7 +182,7 @@ class TestHARDump:
f = format_cookies([("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))])[0] f = format_cookies([("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))])[0]
assert f['expires'] assert f['expires']
def test_binary(self): def test_binary(self, tmpdir):
f = self.flow() f = self.flow()
f.request.method = "POST" f.request.method = "POST"
@ -196,8 +191,7 @@ class TestHARDump:
f.response.headers["random-junk"] = bytes(range(256)) f.response.headers["random-junk"] = bytes(range(256))
f.response.content = bytes(range(256)) f.response.content = bytes(range(256))
with tutils.tmpdir() as tdir: path = str(tmpdir.join("somefile"))
path = os.path.join(tdir, "somefile")
m, sc = tscript("complex/har_dump.py", shlex.quote(path)) m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", f) m.addons.invoke(m, "response", f)
@ -205,5 +199,4 @@ class TestHARDump:
with open(path, "r") as inp: with open(path, "r") as inp:
har = json.load(inp) har = json.load(inp)
assert len(har["log"]["entries"]) == 1 assert len(har["log"]["entries"]) == 1

View File

@ -1,5 +1,4 @@
import copy import copy
import os
import pytest import pytest
import typing import typing
import argparse import argparse
@ -7,7 +6,6 @@ import argparse
from mitmproxy import options from mitmproxy import options
from mitmproxy import optmanager from mitmproxy import optmanager
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.test import tutils
class TO(optmanager.OptManager): class TO(optmanager.OptManager):
@ -238,11 +236,10 @@ def test_serialize_defaults():
assert o.serialize(None, defaults=True) assert o.serialize(None, defaults=True)
def test_saving(): def test_saving(tmpdir):
o = TD2() o = TD2()
o.three = "set" o.three = "set"
with tutils.tmpdir() as tdir: dst = str(tmpdir.join("conf"))
dst = os.path.join(tdir, "conf")
o.save(dst, defaults=True) o.save(dst, defaults=True)
o2 = TD2() o2 = TD2()

View File

@ -1,4 +1,3 @@
import os
import pytest import pytest
from unittest import mock from unittest import mock
@ -9,7 +8,6 @@ from mitmproxy import controller
from mitmproxy import options from mitmproxy import options
from mitmproxy.tools import dump from mitmproxy.tools import dump
from mitmproxy.test import tutils
from .. import tservers from .. import tservers
@ -19,9 +17,8 @@ class TestDumpMaster(tservers.MasterTest):
m = dump.DumpMaster(o, proxy.DummyServer(), with_termlog=False, with_dumper=False) m = dump.DumpMaster(o, proxy.DummyServer(), with_termlog=False, with_dumper=False)
return m return m
def test_read(self): def test_read(self, tmpdir):
with tutils.tmpdir() as t: p = str(tmpdir.join("read"))
p = os.path.join(t, "read")
self.flowfile(p) self.flowfile(p)
self.dummy_cycle( self.dummy_cycle(
self.mkmaster(None, rfile=p), self.mkmaster(None, rfile=p),

View File

@ -1,11 +1,8 @@
import os
import pytest import pytest
from pathod import language from pathod import language
from pathod.language import base, exceptions from pathod.language import base, exceptions
from mitmproxy.test import tutils
def parse_request(s): def parse_request(s):
return language.parse_pathoc(s).next() return language.parse_pathoc(s).next()
@ -137,24 +134,22 @@ class TestTokValueFile:
v = base.TokValue.parseString("<path")[0] v = base.TokValue.parseString("<path")[0]
assert v.path == "path" assert v.path == "path"
def test_access_control(self): def test_access_control(self, tmpdir):
v = base.TokValue.parseString("<path")[0] v = base.TokValue.parseString("<path")[0]
with tutils.tmpdir() as t: f = tmpdir.join("path")
p = os.path.join(t, "path")
with open(p, "wb") as f:
f.write(b"x" * 10000) f.write(b"x" * 10000)
assert v.get_generator(language.Settings(staticdir=t)) assert v.get_generator(language.Settings(staticdir=str(tmpdir)))
v = base.TokValue.parseString("<path2")[0] v = base.TokValue.parseString("<path2")[0]
with pytest.raises(exceptions.FileAccessDenied): with pytest.raises(exceptions.FileAccessDenied):
v.get_generator(language.Settings(staticdir=t)) v.get_generator(language.Settings(staticdir=str(tmpdir)))
with pytest.raises(Exception, match="access disabled"): with pytest.raises(Exception, match="access disabled"):
v.get_generator(language.Settings()) v.get_generator(language.Settings())
v = base.TokValue.parseString("</outside")[0] v = base.TokValue.parseString("</outside")[0]
with pytest.raises(Exception, match="outside"): with pytest.raises(Exception, match="outside"):
v.get_generator(language.Settings(staticdir=t)) v.get_generator(language.Settings(staticdir=str(tmpdir)))
def test_spec(self): def test_spec(self):
v = base.TokValue.parseString("<'one two'")[0] v = base.TokValue.parseString("<'one two'")[0]

View File

@ -1,7 +1,4 @@
import os
from pathod.language import generators from pathod.language import generators
from mitmproxy.test import tutils
def test_randomgenerator(): def test_randomgenerator():
@ -15,13 +12,10 @@ def test_randomgenerator():
assert len(g[1000:1001]) == 0 assert len(g[1000:1001]) == 0
def test_filegenerator(): def test_filegenerator(tmpdir):
with tutils.tmpdir() as t: f = tmpdir.join("foo")
path = os.path.join(t, "foo")
f = open(path, "wb")
f.write(b"x" * 10000) f.write(b"x" * 10000)
f.close() g = generators.FileGenerator(str(f))
g = generators.FileGenerator(path)
assert len(g) == 10000 assert len(g) == 10000
assert g[0] == b"x" assert g[0] == b"x"
assert g[-1] == b"x" assert g[-1] == b"x"