Merge pull request #3035 from cortesi/aiosimpler

asyncio consolidation
This commit is contained in:
Aldo Cortesi 2018-04-07 09:37:58 +12:00 committed by GitHub
commit 5e2a1ec23c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 283 additions and 293 deletions

1
.gitignore vendored
View File

@ -14,6 +14,7 @@ build/
dist/
mitmproxy/contrib/kaitaistruct/*.ksy
.pytest_cache
__pycache__
# UI

View File

@ -8,7 +8,6 @@ from mitmproxy import exceptions
from mitmproxy import eventsequence
from mitmproxy import controller
from mitmproxy import flow
from mitmproxy import log
from . import ctx
import pprint
@ -38,11 +37,8 @@ def cut_traceback(tb, func_name):
class StreamLog:
"""
A class for redirecting output using contextlib.
"""
def __init__(self, log):
self.log = log
def __init__(self, lg):
self.log = lg
def write(self, buf):
if buf.strip():
@ -55,13 +51,7 @@ class StreamLog:
@contextlib.contextmanager
def safecall():
# resolve ctx.master here.
# we want to be threadsafe, and ctx.master may already be cleared when an addon prints().
tell = ctx.master.tell
# don't use master.add_log (which is not thread-safe). Instead, put on event queue.
stdout_replacement = StreamLog(
lambda message: tell("log", log.LogEntry(message, "warn"))
)
stdout_replacement = StreamLog(lambda message: ctx.log.warn(message))
try:
with contextlib.redirect_stdout(stdout_replacement):
yield
@ -189,7 +179,6 @@ class AddonManager:
Add addons to the end of the chain, and run their load event.
If any addon has sub-addons, they are registered.
"""
with self.master.handlecontext():
for i in addons:
self.chain.append(self.register(i))
@ -207,7 +196,6 @@ class AddonManager:
raise exceptions.AddonManagerError("No such addon: %s" % n)
self.chain = [i for i in self.chain if i is not a]
del self.lookup[_get_name(a)]
with self.master.handlecontext():
self.invoke_addon(a, "done")
def __len__(self):
@ -220,7 +208,7 @@ class AddonManager:
name = _get_name(item)
return name in self.lookup
def handle_lifecycle(self, name, message):
async def handle_lifecycle(self, name, message):
"""
Handle a lifecycle event.
"""
@ -251,8 +239,7 @@ class AddonManager:
def invoke_addon(self, addon, name, *args, **kwargs):
"""
Invoke an event on an addon and all its children. This method must
run within an established handler context.
Invoke an event on an addon and all its children.
"""
if name not in eventsequence.Events:
name = "event_" + name
@ -274,9 +261,8 @@ class AddonManager:
def trigger(self, name, *args, **kwargs):
"""
Establish a handler context and trigger an event across all addons
Trigger an event across all addons.
"""
with self.master.handlecontext():
for i in self.chain:
try:
with safecall():

View File

@ -1,4 +1,5 @@
import mitmproxy
from mitmproxy import ctx
class CheckCA:
@ -15,10 +16,9 @@ class CheckCA:
if has_ca:
self.failed = mitmproxy.ctx.master.server.config.certstore.default_ca.has_expired()
if self.failed:
mitmproxy.ctx.master.add_log(
ctx.log.warn(
"The mitmproxy certificate authority has expired!\n"
"Please delete all CA-related files in your ~/.mitmproxy folder.\n"
"The CA will be regenerated automatically after restarting mitmproxy.\n"
"Then make sure all your clients have the new CA installed.",
"warn",
)

View File

@ -95,11 +95,7 @@ class Command:
Call the command with a list of arguments. At this point, all
arguments are strings.
"""
pargs = self.prepare_args(args)
with self.manager.master.handlecontext():
ret = self.func(*pargs)
ret = self.func(*self.prepare_args(args))
if ret is None and self.returntype is None:
return
typ = mitmproxy.types.CommandTypes.get(self.returntype)

View File

@ -8,10 +8,10 @@ class Channel:
The only way for the proxy server to communicate with the master
is to use the channel it has been given.
"""
def __init__(self, loop, q, should_exit):
def __init__(self, master, loop, should_exit):
self.master = master
self.loop = loop
self.should_exit = should_exit
self._q = q
def ask(self, mtype, m):
"""
@ -22,7 +22,10 @@ class Channel:
exceptions.Kill: All connections should be closed immediately.
"""
m.reply = Reply(m)
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
g = m.reply.q.get()
if g == exceptions.Kill:
raise exceptions.Kill()
@ -34,7 +37,10 @@ class Channel:
then return immediately.
"""
m.reply = DummyReply()
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
NO_REPLY = object() # special object we can distinguish from a valid "None" reply.

View File

@ -1,3 +1,5 @@
import asyncio
class LogEntry:
def __init__(self, msg, level):
@ -54,7 +56,9 @@ class Log:
self(txt, "error")
def __call__(self, text, level="info"):
self.master.add_log(text, level)
asyncio.get_event_loop().call_soon(
self.master.addons.trigger, "log", LogEntry(text, level)
)
LogTierOrder = [

View File

@ -1,5 +1,4 @@
import threading
import contextlib
import asyncio
import logging
@ -43,15 +42,12 @@ class Master:
The master handles mitmproxy's main event loop.
"""
def __init__(self, opts):
self.event_queue = asyncio.Queue()
self.should_exit = threading.Event()
self.channel = controller.Channel(
self,
asyncio.get_event_loop(),
self.event_queue,
self.should_exit,
)
asyncio.ensure_future(self.main())
asyncio.ensure_future(self.tick())
self.options = opts or options.Options() # type: options.Options
self.commands = command.CommandManager(self)
@ -59,6 +55,11 @@ class Master:
self._server = None
self.first_tick = True
self.waiting_flows = []
self.log = log.Log(self)
mitmproxy_ctx.master = self
mitmproxy_ctx.log = self.log
mitmproxy_ctx.options = self.options
@property
def server(self):
@ -69,49 +70,11 @@ class Master:
server.set_channel(self.channel)
self._server = server
@contextlib.contextmanager
def handlecontext(self):
# Handlecontexts also have to nest - leave cleanup to the outermost
if mitmproxy_ctx.master:
yield
return
mitmproxy_ctx.master = self
mitmproxy_ctx.log = log.Log(self)
mitmproxy_ctx.options = self.options
try:
yield
finally:
mitmproxy_ctx.master = None
mitmproxy_ctx.log = None
mitmproxy_ctx.options = None
# This is a vestigial function that will go away in a refactor very soon
def tell(self, mtype, m): # pragma: no cover
m.reply = controller.DummyReply()
self.event_queue.put((mtype, m))
def add_log(self, e, level):
"""
level: debug, alert, info, warn, error
"""
self.addons.trigger("log", log.LogEntry(e, level))
def start(self):
self.should_exit.clear()
if self.server:
ServerThread(self.server).start()
async def main(self):
while True:
try:
mtype, obj = await self.event_queue.get()
except RuntimeError:
return
if mtype not in eventsequence.Events: # pragma: no cover
raise exceptions.ControlException("Unknown event %s" % repr(mtype))
self.addons.handle_lifecycle(mtype, obj)
self.event_queue.task_done()
async def tick(self):
if self.first_tick:
self.first_tick = False
@ -150,7 +113,7 @@ class Master:
f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme
def load_flow(self, f):
async def load_flow(self, f):
"""
Loads a flow and links websocket & handshake flows
"""
@ -168,7 +131,7 @@ class Master:
f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o)
await self.addons.handle_lifecycle(e, o)
def replay_request(
self,

View File

@ -1,4 +1,5 @@
import contextlib
import asyncio
import sys
import mitmproxy.master
@ -34,7 +35,7 @@ class RecordingMaster(mitmproxy.master.Master):
for i in self.logs:
print("%s: %s" % (i.level, i.msg), file=outf)
def has_log(self, txt, level=None):
def _has_log(self, txt, level=None):
for i in self.logs:
if level and i.level != level:
continue
@ -42,6 +43,14 @@ class RecordingMaster(mitmproxy.master.Master):
return True
return False
async def await_log(self, txt, level=None):
for i in range(20):
if self._has_log(txt, level):
return True
else:
await asyncio.sleep(0.1)
return False
def has_event(self, name):
for i in self.events:
if i[0] == name:
@ -65,7 +74,6 @@ class context:
options
)
self.options = self.master.options
self.wrapped = None
if loadcore:
self.master.addons.add(core.Core())
@ -73,20 +81,10 @@ class context:
for a in addons:
self.master.addons.add(a)
def ctx(self):
"""
Returns a new handler context.
"""
return self.master.handlecontext()
def __enter__(self):
self.wrapped = self.ctx()
self.wrapped.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.wrapped.__exit__(exc_type, exc_value, traceback)
self.wrapped = None
return False
@contextlib.contextmanager

View File

@ -121,7 +121,7 @@ class FlowDetails(tabs.Tabs):
viewmode, message
)
if error:
self.master.add_log(error, "debug")
self.master.log.debug(error)
# Give hint that you have to tab for the response.
if description == "No content" and isinstance(message, http.HTTPRequest):
description = "No request content (press tab to view response)"

View File

@ -4,6 +4,7 @@ import logging
import os.path
import re
from io import BytesIO
import asyncio
import mitmproxy.flow
import tornado.escape
@ -235,7 +236,7 @@ class DumpFlows(RequestHandler):
self.view.clear()
bio = BytesIO(self.filecontents)
for i in io.FlowReader(bio).stream():
self.master.load_flow(i)
asyncio.call_soon(self.master.load_flow, i)
bio.close()

View File

@ -114,17 +114,15 @@ class WebMaster(master.Master):
iol.add_callback(self.start)
web_url = "http://{}:{}/".format(self.options.web_iface, self.options.web_port)
self.add_log(
self.log.info(
"Web server listening at {}".format(web_url),
"info"
)
if self.options.web_open_browser:
success = open_browser(web_url)
if not success:
self.add_log(
self.log.info(
"No web browser found. Please open a browser and point it to {}".format(web_url),
"info"
)
try:
iol.start()

View File

@ -88,6 +88,7 @@ setup(
"flake8>=3.5, <3.6",
"Flask>=0.10.1, <0.13",
"mypy>=0.580,<0.581",
"pytest-asyncio>=0.8",
"pytest-cov>=2.5.1,<3",
"pytest-faulthandler>=1.3.1,<2",
"pytest-timeout>=1.2.1,<2",

View File

@ -17,7 +17,8 @@ from mitmproxy.test import taddons
(False, "fe80::", False),
(False, "2001:4860:4860::8888", True),
])
def test_allowremote(allow_remote, ip, should_be_killed):
@pytest.mark.asyncio
async def test_allowremote(allow_remote, ip, should_be_killed):
ar = allowremote.AllowRemote()
up = proxyauth.ProxyAuth()
with taddons.context(ar, up) as tctx:
@ -28,7 +29,7 @@ def test_allowremote(allow_remote, ip, should_be_killed):
ar.clientconnect(layer)
if should_be_killed:
assert tctx.master.has_log("Client connection was killed", "warn")
assert await tctx.master.await_log("Client connection was killed", "warn")
else:
assert tctx.master.logs == []
tctx.master.clear()

View File

@ -1,31 +1,33 @@
from unittest import mock
import pytest
from mitmproxy.addons import browser
from mitmproxy.test import taddons
def test_browser():
@pytest.mark.asyncio
async def test_browser():
with mock.patch("subprocess.Popen") as po, mock.patch("shutil.which") as which:
which.return_value = "chrome"
b = browser.Browser()
with taddons.context() as tctx:
b.start()
assert po.called
b.start()
assert not tctx.master.has_log("already running")
b.start()
b.browser.poll = lambda: None
b.start()
assert tctx.master.has_log("already running")
assert await tctx.master.await_log("already running")
b.done()
assert not b.browser
def test_no_browser():
@pytest.mark.asyncio
async def test_no_browser():
with mock.patch("shutil.which") as which:
which.return_value = False
b = browser.Browser()
with taddons.context() as tctx:
b.start()
assert tctx.master.has_log("platform is not supported")
assert await tctx.master.await_log("platform is not supported")

View File

@ -8,12 +8,15 @@ from mitmproxy.test import taddons
class TestCheckCA:
@pytest.mark.parametrize('expired', [False, True])
def test_check_ca(self, expired):
@pytest.mark.asyncio
async def test_check_ca(self, expired):
msg = 'The mitmproxy certificate authority has expired!'
with taddons.context() as tctx:
tctx.master.server = mock.MagicMock()
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(return_value=expired)
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(
return_value = expired
)
a = check_ca.CheckCA()
tctx.configure(a)
assert tctx.master.has_log(msg) is expired
assert await tctx.master.await_log(msg) == expired

View File

@ -71,7 +71,8 @@ def qr(f):
return fp.read()
def test_cut_clip():
@pytest.mark.asyncio
async def test_cut_clip():
v = view.View()
c = cut.Cut()
with taddons.context() as tctx:
@ -95,7 +96,7 @@ def test_cut_clip():
"copy/paste mechanism for your system."
pc.side_effect = pyperclip.PyperclipException(log_message)
tctx.command(c.clip, "@all", "request.method")
assert tctx.master.has_log(log_message, level="error")
assert await tctx.master.await_log(log_message, level="error")
def test_cut_save(tmpdir):
@ -125,7 +126,8 @@ def test_cut_save(tmpdir):
(IsADirectoryError, "Is a directory"),
(FileNotFoundError, "No such file or directory")
])
def test_cut_save_open(exception, log_message, tmpdir):
@pytest.mark.asyncio
async def test_cut_save_open(exception, log_message, tmpdir):
f = str(tmpdir.join("path"))
v = view.View()
c = cut.Cut()
@ -136,7 +138,7 @@ def test_cut_save_open(exception, log_message, tmpdir):
with mock.patch("mitmproxy.addons.cut.open") as m:
m.side_effect = exception(log_message)
tctx.command(c.save, "@all", "request.method", f)
assert tctx.master.has_log(log_message, level="error")
assert await tctx.master.await_log(log_message, level="error")
def test_cut():

View File

@ -141,15 +141,16 @@ def test_echo_request_line():
class TestContentView:
@mock.patch("mitmproxy.contentviews.auto.ViewAuto.__call__")
def test_contentview(self, view_auto):
view_auto.side_effect = exceptions.ContentViewException("")
@pytest.mark.asyncio
async def test_contentview(self):
with mock.patch("mitmproxy.contentviews.auto.ViewAuto.__call__") as va:
va.side_effect = exceptions.ContentViewException("")
sio = io.StringIO()
d = dumper.Dumper(sio)
with taddons.context(d) as ctx:
ctx.configure(d, flow_detail=4)
d.response(tflow.tflow())
assert ctx.master.has_log("content viewer failed")
assert await ctx.master.await_log("content viewer failed")
def test_tcp():

View File

@ -125,17 +125,19 @@ def test_export(tmpdir):
(IsADirectoryError, "Is a directory"),
(FileNotFoundError, "No such file or directory")
])
def test_export_open(exception, log_message, tmpdir):
@pytest.mark.asyncio
async def test_export_open(exception, log_message, tmpdir):
f = str(tmpdir.join("path"))
e = export.Export()
with taddons.context() as tctx:
with mock.patch("mitmproxy.addons.export.open") as m:
m.side_effect = exception(log_message)
e.file("raw", tflow.tflow(resp=True), f)
assert tctx.master.has_log(log_message, level="error")
assert await tctx.master.await_log(log_message, level="error")
def test_clip(tmpdir):
@pytest.mark.asyncio
async def test_clip(tmpdir):
e = export.Export()
with taddons.context() as tctx:
with pytest.raises(exceptions.CommandError):
@ -158,4 +160,4 @@ def test_clip(tmpdir):
"copy/paste mechanism for your system."
pc.side_effect = pyperclip.PyperclipException(log_message)
e.clip("raw", tflow.tflow(resp=True))
assert tctx.master.has_log(log_message, level="error")
assert await tctx.master.await_log(log_message, level="error")

View File

@ -55,26 +55,28 @@ class TestReadFile:
with pytest.raises(exceptions.OptionsError):
rf.running()
@mock.patch('mitmproxy.master.Master.load_flow')
def test_corrupt(self, mck, corrupt_data):
@pytest.mark.asyncio
async def test_corrupt(self, corrupt_data):
rf = readfile.ReadFile()
with taddons.context(rf) as tctx:
with mock.patch('mitmproxy.master.Master.load_flow') as mck:
with pytest.raises(exceptions.FlowReadException):
rf.load_flows(io.BytesIO(b"qibble"))
assert not mck.called
assert len(tctx.master.logs) == 1
tctx.master.clear()
with pytest.raises(exceptions.FlowReadException):
rf.load_flows(corrupt_data)
assert await tctx.master.await_log("file corrupted")
assert mck.called
assert len(tctx.master.logs) == 2
def test_nonexisting_file(self):
@pytest.mark.asyncio
async def test_nonexisting_file(self):
rf = readfile.ReadFile()
with taddons.context(rf) as tctx:
with pytest.raises(exceptions.FlowReadException):
rf.load_flows_from_path("nonexistent")
assert len(tctx.master.logs) == 1
assert await tctx.master.await_log("nonexistent")
class TestReadFileStdin:

View File

@ -79,7 +79,8 @@ class TestReplaceFile:
r.request(f)
assert f.request.content == b"bar"
def test_nonexistent(self, tmpdir):
@pytest.mark.asyncio
async def test_nonexistent(self, tmpdir):
r = replace.Replace()
with taddons.context(r) as tctx:
with pytest.raises(Exception, match="Invalid file path"):
@ -97,6 +98,5 @@ class TestReplaceFile:
tmpfile.remove()
f = tflow.tflow()
f.request.content = b"foo"
assert not tctx.master.logs
r.request(f)
assert tctx.master.logs
assert await tctx.master.await_log("could not read")

View File

@ -1,13 +1,11 @@
import os
import sys
import traceback
from unittest import mock
import pytest
from mitmproxy import addonmanager
from mitmproxy import exceptions
from mitmproxy import log
from mitmproxy.addons import script
from mitmproxy.test import taddons
from mitmproxy.test import tflow
@ -49,17 +47,15 @@ def test_load_fullname():
assert not hasattr(ns2, "addons")
def test_script_print_stdout():
@pytest.mark.asyncio
async def test_script_print_stdout():
with taddons.context() as tctx:
with mock.patch('mitmproxy.ctx.master.tell') as mock_warn:
with addonmanager.safecall():
ns = script.load_script(
tutils.test_data.path(
"mitmproxy/data/addonscripts/print.py"
)
tutils.test_data.path("mitmproxy/data/addonscripts/print.py")
)
ns.load(addonmanager.Loader(tctx.master))
mock_warn.assert_called_once_with("log", log.LogEntry("stdoutprint", "warn"))
assert await tctx.master.await_log("stdoutprint")
class TestScript:
@ -101,7 +97,8 @@ class TestScript:
assert rec.call_log[0][1] == "request"
def test_reload(self, tmpdir):
@pytest.mark.asyncio
async def test_reload(self, tmpdir):
with taddons.context() as tctx:
f = tmpdir.join("foo.py")
f.ensure(file=True)
@ -109,15 +106,15 @@ class TestScript:
sc = script.Script(str(f))
tctx.configure(sc)
sc.tick()
assert tctx.master.has_log("Loading")
assert await tctx.master.await_log("Loading")
tctx.master.clear()
assert not tctx.master.has_log("Loading")
sc.last_load, sc.last_mtime = 0, 0
sc.tick()
assert tctx.master.has_log("Loading")
assert await tctx.master.await_log("Loading")
def test_exception(self):
@pytest.mark.asyncio
async def test_exception(self):
with taddons.context() as tctx:
sc = script.Script(
tutils.test_data.path("mitmproxy/data/addonscripts/error.py")
@ -129,8 +126,8 @@ class TestScript:
f = tflow.tflow(resp=True)
tctx.master.addons.trigger("request", f)
assert tctx.master.has_log("ValueError: Error!")
assert tctx.master.has_log("error.py")
assert await tctx.master.await_log("ValueError: Error!")
assert await tctx.master.await_log("error.py")
def test_addon(self):
with taddons.context() as tctx:
@ -166,13 +163,15 @@ class TestCutTraceback:
class TestScriptLoader:
def test_script_run(self):
@pytest.mark.asyncio
async def test_script_run(self):
rp = tutils.test_data.path(
"mitmproxy/data/addonscripts/recorder/recorder.py"
)
sc = script.ScriptLoader()
with taddons.context(sc) as tctx:
sc.script_run([tflow.tflow(resp=True)], rp)
await tctx.master.await_log("recorder response")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [
'recorder load', 'recorder running', 'recorder configure',
@ -181,11 +180,12 @@ class TestScriptLoader:
'recorder responseheaders', 'recorder response'
]
def test_script_run_nonexistent(self):
@pytest.mark.asyncio
async def test_script_run_nonexistent(self):
sc = script.ScriptLoader()
with taddons.context(sc) as tctx:
sc.script_run([tflow.tflow(resp=True)], "/")
tctx.master.has_log("/: No such script")
assert await tctx.master.await_log("/: No such script")
def test_simple(self):
sc = script.ScriptLoader()
@ -243,19 +243,21 @@ class TestScriptLoader:
tctx.invoke(sc, "tick")
assert len(tctx.master.addons) == 1
def test_script_error_handler(self):
@pytest.mark.asyncio
async def test_script_error_handler(self):
path = "/sample/path/example.py"
exc = SyntaxError
msg = "Error raised"
tb = True
with taddons.context() as tctx:
script.script_error_handler(path, exc, msg, tb)
assert tctx.master.has_log("/sample/path/example.py")
assert tctx.master.has_log("Error raised")
assert tctx.master.has_log("lineno")
assert tctx.master.has_log("NoneType")
assert await tctx.master.await_log("/sample/path/example.py")
assert await tctx.master.await_log("Error raised")
assert await tctx.master.await_log("lineno")
assert await tctx.master.await_log("NoneType")
def test_order(self):
@pytest.mark.asyncio
async def test_order(self):
rec = tutils.test_data.path("mitmproxy/data/addonscripts/recorder")
sc = script.ScriptLoader()
sc.is_running = True
@ -269,6 +271,7 @@ class TestScriptLoader:
]
)
tctx.master.addons.invoke_addon(sc, "tick")
await tctx.master.await_log("c tick")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [
'a load',
@ -287,7 +290,7 @@ class TestScriptLoader:
'c tick',
]
tctx.master.logs = []
tctx.master.clear()
tctx.configure(
sc,
scripts = [
@ -297,6 +300,7 @@ class TestScriptLoader:
]
)
await tctx.master.await_log("c configure")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [
'c configure',
@ -313,6 +317,7 @@ class TestScriptLoader:
]
)
tctx.master.addons.invoke_addon(sc, "tick")
await tctx.master.await_log("a tick")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [

View File

@ -1,15 +1,17 @@
import pytest
from mitmproxy import proxy
from mitmproxy.addons import termstatus
from mitmproxy.test import taddons
def test_configure():
@pytest.mark.asyncio
async def test_configure():
ts = termstatus.TermStatus()
with taddons.context() as ctx:
ctx.master.server = proxy.DummyServer()
ctx.configure(ts, server=False)
ts.running()
assert not ctx.master.logs
ctx.configure(ts, server=True)
ts.running()
assert ctx.master.logs
await ctx.master.await_log("server listening")

View File

@ -159,7 +159,8 @@ def test_orders():
assert v.order_options()
def test_load(tmpdir):
@pytest.mark.asyncio
async def test_load(tmpdir):
path = str(tmpdir.join("path"))
v = view.View()
with taddons.context() as tctx:
@ -182,7 +183,7 @@ def test_load(tmpdir):
with open(path, "wb") as f:
f.write(b"invalidflows")
v.load_file(path)
assert tctx.master.has_log("Invalid data format.")
assert await tctx.master.await_log("Invalid data format.")
def test_resolve():

View File

@ -3,7 +3,6 @@ import os
import struct
import tempfile
import traceback
import time
from mitmproxy import options
from mitmproxy import exceptions
@ -48,6 +47,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
class _WebSocketTestBase:
client = None
@classmethod
def setup_class(cls):
@ -286,7 +286,8 @@ class TestPing(_WebSocketTest):
wfile.flush()
websockets.Frame.from_file(rfile)
def test_ping(self):
@pytest.mark.asyncio
async def test_ping(self):
self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile)
@ -296,7 +297,7 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'' # We don't send payload to other end
assert self.master.has_log("Pong Received from server", "info")
assert await self.master.await_log("Pong Received from server", "info")
class TestPong(_WebSocketTest):
@ -314,7 +315,8 @@ class TestPong(_WebSocketTest):
wfile.flush()
websockets.Frame.from_file(rfile)
def test_pong(self):
@pytest.mark.asyncio
async def test_pong(self):
self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
@ -327,12 +329,7 @@ class TestPong(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
for i in range(20):
if self.master.has_log("Pong Received from server", "info"):
break
time.sleep(0.01)
else:
raise AssertionError("No pong seen")
assert await self.master.await_log("pong received")
class TestClose(_WebSocketTest):

View File

@ -1,3 +1,4 @@
import asyncio
import os
import socket
import time
@ -123,8 +124,6 @@ class TcpMixin:
i2 = self.pathod("306")
self._ignore_off()
self.master.event_queue.join()
assert n.status_code == 304
assert i.status_code == 305
assert i2.status_code == 306
@ -168,8 +167,6 @@ class TcpMixin:
i2 = self.pathod("306")
self._tcpproxy_off()
self.master.event_queue.join()
assert n.status_code == 304
assert i.status_code == 305
assert i2.status_code == 306
@ -238,13 +235,14 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin):
assert p.request(req)
assert p.request(req)
def test_get_connection_switching(self):
@pytest.mark.asyncio
async def test_get_connection_switching(self):
req = "get:'%s/p/200:b@1'"
p = self.pathoc()
with p.connect():
assert p.request(req % self.server.urlbase)
assert p.request(req % self.server2.urlbase)
assert self.proxy.tmaster.has_log("serverdisconnect")
assert await self.proxy.tmaster.await_log("serverdisconnect")
def test_blank_leading_line(self):
p = self.pathoc()
@ -447,13 +445,14 @@ class TestReverse(tservers.ReverseProxyTest, CommonMixin, TcpMixin):
req = self.master.state.flows[0].request
assert req.host_header == "127.0.0.1"
def test_selfconnection(self):
@pytest.mark.asyncio
async def test_selfconnection(self):
self.options.mode = "reverse:http://127.0.0.1:0"
p = self.pathoc()
with p.connect():
p.request("get:/")
assert self.master.has_log("The proxy shall not connect to itself.")
assert await self.master.await_log("The proxy shall not connect to itself.")
class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin):
@ -553,7 +552,6 @@ class TestHttps2Http(tservers.ReverseProxyTest):
p = self.pathoc(ssl=True, sni="example.com")
with p.connect():
assert p.request("get:'/p/200'").status_code == 200
assert not self.proxy.tmaster.has_log("error in handle_sni")
def test_http(self):
p = self.pathoc(ssl=False)
@ -818,11 +816,13 @@ class TestServerConnect(tservers.HTTPProxyTest):
opts.upstream_cert = False
return opts
def test_unnecessary_serverconnect(self):
@pytest.mark.asyncio
async def test_unnecessary_serverconnect(self):
"""A replayed/fake response with no upstream_cert should not connect to an upstream server"""
self.set_addons(AFakeResponse())
assert self.pathod("200").status_code == 200
assert not self.proxy.tmaster.has_log("serverconnect")
asyncio.sleep(0.1)
assert not self.proxy.tmaster._has_log("serverconnect")
class AKillRequest:

View File

@ -1,3 +1,5 @@
import pytest
from mitmproxy.test import tflow
from mitmproxy.test import tutils
from mitmproxy.test import taddons
@ -31,14 +33,15 @@ class TestConcurrent(tservers.MasterTest):
return
raise ValueError("Script never acked")
def test_concurrent_err(self):
@pytest.mark.asyncio
async def test_concurrent_err(self):
with taddons.context() as tctx:
tctx.script(
tutils.test_data.path(
"mitmproxy/data/addonscripts/concurrent_decorator_err.py"
)
)
assert tctx.master.has_log("decorator not supported")
assert await tctx.master.await_log("decorator not supported")
def test_concurrent_class(self):
with taddons.context() as tctx:

View File

@ -1,4 +1,6 @@
import pytest
from unittest import mock
from mitmproxy import addons
from mitmproxy import addonmanager
@ -65,7 +67,8 @@ def test_halt():
assert end.custom_called
def test_lifecycle():
@pytest.mark.asyncio
async def test_lifecycle():
o = options.Options()
m = master.Master(o)
a = addonmanager.AddonManager(m)
@ -77,7 +80,7 @@ def test_lifecycle():
a.remove(TAddon("nonexistent"))
f = tflow.tflow()
a.handle_lifecycle("request", f)
await a.handle_lifecycle("request", f)
a._configure_all(o, o.keys())
@ -86,19 +89,21 @@ def test_defaults():
assert addons.default_addons()
def test_loader():
@pytest.mark.asyncio
async def test_loader():
with taddons.context() as tctx:
with mock.patch("mitmproxy.ctx.log.warn") as warn:
l = addonmanager.Loader(tctx.master)
l.add_option("custom_option", bool, False, "help")
assert "custom_option" in l.master.options
# calling this again with the same signature is a no-op.
l.add_option("custom_option", bool, False, "help")
assert not tctx.master.has_log("Over-riding existing option")
assert not warn.called
# a different signature should emit a warning though.
l.add_option("custom_option", bool, True, "help")
assert tctx.master.has_log("Over-riding existing option")
assert warn.called
def cmd(a: str) -> str:
return "foo"
@ -106,7 +111,8 @@ def test_loader():
l.add_command("test.command", cmd)
def test_simple():
@pytest.mark.asyncio
async def test_simple():
with taddons.context(loadcore=False) as tctx:
a = tctx.master.addons
@ -120,14 +126,14 @@ def test_simple():
assert not a.chain
a.add(TAddon("one"))
a.trigger("done")
a.trigger("running")
a.trigger("tick")
assert tctx.master.has_log("not callable")
assert await tctx.master.await_log("not callable")
tctx.master.clear()
a.get("one").tick = addons
a.trigger("tick")
assert not tctx.master.has_log("not callable")
assert not await tctx.master.await_log("not callable")
a.remove(a.get("one"))
assert not a.get("one")

View File

@ -5,12 +5,11 @@ import pytest
from mitmproxy.exceptions import Kill, ControlException
from mitmproxy import controller
from mitmproxy.test import taddons
import mitmproxy.ctx
@pytest.mark.asyncio
async def test_master():
class TMsg:
pass
class tAddon:
def log(self, _):
@ -20,12 +19,11 @@ async def test_master():
assert not ctx.master.should_exit.is_set()
async def test():
msg = TMsg()
msg.reply = controller.DummyReply()
await ctx.master.channel.tell("log", msg)
mitmproxy.ctx.log("test")
asyncio.ensure_future(test())
assert not ctx.master.should_exit.is_set()
assert await ctx.master.await_log("test")
assert ctx.master.should_exit.is_set()
class TestReply:

View File

@ -2,7 +2,7 @@ import io
from unittest import mock
import pytest
from mitmproxy.test import tflow, tutils
from mitmproxy.test import tflow, tutils, taddons
import mitmproxy.io
from mitmproxy import flowfilter
from mitmproxy import options
@ -97,27 +97,27 @@ class TestSerialize:
class TestFlowMaster:
def test_load_http_flow_reverse(self):
s = tservers.TestState()
@pytest.mark.asyncio
async def test_load_http_flow_reverse(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
s = tservers.TestState()
with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(resp=True)
fm.load_flow(f)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
def test_load_websocket_flow(self):
s = tservers.TestState()
@pytest.mark.asyncio
async def test_load_websocket_flow(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
s = tservers.TestState()
with taddons.context(s, options=opts) as ctx:
f = tflow.twebsocketflow()
fm.load_flow(f.handshake_flow)
fm.load_flow(f)
await ctx.master.load_flow(f.handshake_flow)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
@ -150,31 +150,27 @@ class TestFlowMaster:
assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers
def test_all(self):
@pytest.mark.asyncio
async def test_all(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
s = tservers.TestState()
fm = master.Master(None)
fm.addons.add(s)
with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(req=None)
fm.addons.handle_lifecycle("clientconnect", f.client_conn)
await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
fm.addons.handle_lifecycle("request", f)
await ctx.master.addons.handle_lifecycle("request", f)
assert len(s.flows) == 1
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
fm.addons.handle_lifecycle("response", f)
await ctx.master.addons.handle_lifecycle("response", f)
assert len(s.flows) == 1
fm.addons.handle_lifecycle("clientdisconnect", f.client_conn)
await ctx.master.addons.handle_lifecycle("clientdisconnect", f.client_conn)
f.error = flow.Error("msg")
fm.addons.handle_lifecycle("error", f)
# FIXME: This no longer works, because we consume on the main loop.
# fm.tell("foo", f)
# with pytest.raises(ControlException):
# fm.addons.trigger("unknown")
fm.shutdown()
await ctx.master.addons.handle_lifecycle("error", f)
class TestError:

View File

@ -1,21 +1,27 @@
import io
import pytest
from mitmproxy.test import taddons
from mitmproxy.test import tutils
from mitmproxy import ctx
def test_recordingmaster():
@pytest.mark.asyncio
async def test_recordingmaster():
with taddons.context() as tctx:
assert not tctx.master.has_log("nonexistent")
assert not tctx.master._has_log("nonexistent")
assert not tctx.master.has_event("nonexistent")
ctx.log.error("foo")
assert not tctx.master.has_log("foo", level="debug")
assert tctx.master.has_log("foo", level="error")
assert not tctx.master._has_log("foo", level="debug")
assert await tctx.master.await_log("foo", level="error")
def test_dumplog():
@pytest.mark.asyncio
async def test_dumplog():
with taddons.context() as tctx:
ctx.log.info("testing")
await ctx.master.await_log("testing")
s = io.StringIO()
tctx.master.dump_log(s)
assert s.getvalue()

View File

@ -4,13 +4,16 @@ from mitmproxy.tools.console import keymap
from mitmproxy.tools.console import master
from mitmproxy import command
import pytest
def test_commands_exist():
@pytest.mark.asyncio
async def test_commands_exist():
km = keymap.Keymap(None)
defaultkeys.map(km)
assert km.bindings
m = master.ConsoleMaster(None)
m.load_flow(tflow())
await m.load_flow(tflow())
for binding in km.bindings:
cmd, *args = command.lexer(binding.command)

View File

@ -4,7 +4,10 @@ from mitmproxy import options
from mitmproxy.tools import console
from ... import tservers
import pytest
@pytest.mark.asyncio
class TestMaster(tservers.MasterTest):
def mkmaster(self, **opts):
o = options.Options(**opts)
@ -12,11 +15,11 @@ class TestMaster(tservers.MasterTest):
m.addons.trigger("configure", o.keys())
return m
def test_basic(self):
async def test_basic(self):
m = self.mkmaster()
for i in (1, 2, 3):
try:
self.dummy_cycle(m, 1, b"")
await self.dummy_cycle(m, 1, b"")
except urwid.ExitMainLoop:
pass
assert len(m.view) == i

View File

@ -2,6 +2,7 @@ import json as _json
import logging
from unittest import mock
import os
import asyncio
import pytest
import tornado.testing
@ -32,6 +33,11 @@ def json(resp: httpclient.HTTPResponse):
@pytest.mark.usefixtures("no_tornado_logging")
class TestApp(tornado.testing.AsyncHTTPTestCase):
def get_new_ioloop(self):
io_loop = tornado.platform.asyncio.AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop
def get_app(self):
o = options.Options(http2=False)
m = webmaster.WebMaster(o, with_termlog=False)
@ -39,7 +45,7 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
f.id = "42"
m.view.add([f])
m.view.add([tflow.tflow(err=True)])
m.add_log("test log", "info")
m.log.info("test log")
self.master = m
self.view = m.view
self.events = m.events
@ -75,12 +81,6 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
resp = self.fetch("/flows/dump")
assert b"address" in resp.body
self.view.clear()
assert not len(self.view)
assert self.fetch("/flows/dump", method="POST", body=resp.body).code == 200
assert len(self.view)
def test_clear(self):
events = self.events.data.copy()
flows = list(self.view)

View File

@ -1,6 +1,8 @@
from mitmproxy.tools.web import master
from mitmproxy import options
import pytest
from ... import tservers
@ -9,8 +11,9 @@ class TestWebMaster(tservers.MasterTest):
o = options.Options(**opts)
return master.WebMaster(o)
def test_basic(self):
@pytest.mark.asyncio
async def test_basic(self):
m = self.mkmaster()
for i in (1, 2, 3):
self.dummy_cycle(m, 1, b"")
await self.dummy_cycle(m, 1, b"")
assert len(m.view) == i

View File

@ -26,20 +26,20 @@ from mitmproxy.test import taddons
class MasterTest:
def cycle(self, master, content):
async def cycle(self, master, content):
f = tflow.tflow(req=tutils.treq(content=content))
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
layer.client_conn = f.client_conn
layer.reply = controller.DummyReply()
master.addons.handle_lifecycle("clientconnect", layer)
await master.addons.handle_lifecycle("clientconnect", layer)
for i in eventsequence.iterate(f):
master.addons.handle_lifecycle(*i)
master.addons.handle_lifecycle("clientdisconnect", layer)
await master.addons.handle_lifecycle(*i)
await master.addons.handle_lifecycle("clientdisconnect", layer)
return f
def dummy_cycle(self, master, n, content):
async def dummy_cycle(self, master, n, content):
for i in range(n):
self.cycle(master, content)
await self.cycle(master, content)
master.shutdown()
def flowfile(self, path):