Merge pull request #3035 from cortesi/aiosimpler

asyncio consolidation
This commit is contained in:
Aldo Cortesi 2018-04-07 09:37:58 +12:00 committed by GitHub
commit 5e2a1ec23c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 283 additions and 293 deletions

1
.gitignore vendored
View File

@ -14,6 +14,7 @@ build/
dist/ dist/
mitmproxy/contrib/kaitaistruct/*.ksy mitmproxy/contrib/kaitaistruct/*.ksy
.pytest_cache .pytest_cache
__pycache__
# UI # UI

View File

@ -8,7 +8,6 @@ from mitmproxy import exceptions
from mitmproxy import eventsequence from mitmproxy import eventsequence
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import log
from . import ctx from . import ctx
import pprint import pprint
@ -38,11 +37,8 @@ def cut_traceback(tb, func_name):
class StreamLog: class StreamLog:
""" def __init__(self, lg):
A class for redirecting output using contextlib. self.log = lg
"""
def __init__(self, log):
self.log = log
def write(self, buf): def write(self, buf):
if buf.strip(): if buf.strip():
@ -55,13 +51,7 @@ class StreamLog:
@contextlib.contextmanager @contextlib.contextmanager
def safecall(): def safecall():
# resolve ctx.master here. stdout_replacement = StreamLog(lambda message: ctx.log.warn(message))
# we want to be threadsafe, and ctx.master may already be cleared when an addon prints().
tell = ctx.master.tell
# don't use master.add_log (which is not thread-safe). Instead, put on event queue.
stdout_replacement = StreamLog(
lambda message: tell("log", log.LogEntry(message, "warn"))
)
try: try:
with contextlib.redirect_stdout(stdout_replacement): with contextlib.redirect_stdout(stdout_replacement):
yield yield
@ -189,7 +179,6 @@ class AddonManager:
Add addons to the end of the chain, and run their load event. Add addons to the end of the chain, and run their load event.
If any addon has sub-addons, they are registered. If any addon has sub-addons, they are registered.
""" """
with self.master.handlecontext():
for i in addons: for i in addons:
self.chain.append(self.register(i)) self.chain.append(self.register(i))
@ -207,7 +196,6 @@ class AddonManager:
raise exceptions.AddonManagerError("No such addon: %s" % n) raise exceptions.AddonManagerError("No such addon: %s" % n)
self.chain = [i for i in self.chain if i is not a] self.chain = [i for i in self.chain if i is not a]
del self.lookup[_get_name(a)] del self.lookup[_get_name(a)]
with self.master.handlecontext():
self.invoke_addon(a, "done") self.invoke_addon(a, "done")
def __len__(self): def __len__(self):
@ -220,7 +208,7 @@ class AddonManager:
name = _get_name(item) name = _get_name(item)
return name in self.lookup return name in self.lookup
def handle_lifecycle(self, name, message): async def handle_lifecycle(self, name, message):
""" """
Handle a lifecycle event. Handle a lifecycle event.
""" """
@ -251,8 +239,7 @@ class AddonManager:
def invoke_addon(self, addon, name, *args, **kwargs): def invoke_addon(self, addon, name, *args, **kwargs):
""" """
Invoke an event on an addon and all its children. This method must Invoke an event on an addon and all its children.
run within an established handler context.
""" """
if name not in eventsequence.Events: if name not in eventsequence.Events:
name = "event_" + name name = "event_" + name
@ -274,9 +261,8 @@ class AddonManager:
def trigger(self, name, *args, **kwargs): def trigger(self, name, *args, **kwargs):
""" """
Establish a handler context and trigger an event across all addons Trigger an event across all addons.
""" """
with self.master.handlecontext():
for i in self.chain: for i in self.chain:
try: try:
with safecall(): with safecall():

View File

@ -1,4 +1,5 @@
import mitmproxy import mitmproxy
from mitmproxy import ctx
class CheckCA: class CheckCA:
@ -15,10 +16,9 @@ class CheckCA:
if has_ca: if has_ca:
self.failed = mitmproxy.ctx.master.server.config.certstore.default_ca.has_expired() self.failed = mitmproxy.ctx.master.server.config.certstore.default_ca.has_expired()
if self.failed: if self.failed:
mitmproxy.ctx.master.add_log( ctx.log.warn(
"The mitmproxy certificate authority has expired!\n" "The mitmproxy certificate authority has expired!\n"
"Please delete all CA-related files in your ~/.mitmproxy folder.\n" "Please delete all CA-related files in your ~/.mitmproxy folder.\n"
"The CA will be regenerated automatically after restarting mitmproxy.\n" "The CA will be regenerated automatically after restarting mitmproxy.\n"
"Then make sure all your clients have the new CA installed.", "Then make sure all your clients have the new CA installed.",
"warn",
) )

View File

@ -95,11 +95,7 @@ class Command:
Call the command with a list of arguments. At this point, all Call the command with a list of arguments. At this point, all
arguments are strings. arguments are strings.
""" """
pargs = self.prepare_args(args) ret = self.func(*self.prepare_args(args))
with self.manager.master.handlecontext():
ret = self.func(*pargs)
if ret is None and self.returntype is None: if ret is None and self.returntype is None:
return return
typ = mitmproxy.types.CommandTypes.get(self.returntype) typ = mitmproxy.types.CommandTypes.get(self.returntype)

View File

@ -8,10 +8,10 @@ class Channel:
The only way for the proxy server to communicate with the master The only way for the proxy server to communicate with the master
is to use the channel it has been given. is to use the channel it has been given.
""" """
def __init__(self, loop, q, should_exit): def __init__(self, master, loop, should_exit):
self.master = master
self.loop = loop self.loop = loop
self.should_exit = should_exit self.should_exit = should_exit
self._q = q
def ask(self, mtype, m): def ask(self, mtype, m):
""" """
@ -22,7 +22,10 @@ class Channel:
exceptions.Kill: All connections should be closed immediately. exceptions.Kill: All connections should be closed immediately.
""" """
m.reply = Reply(m) m.reply = Reply(m)
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop) asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
g = m.reply.q.get() g = m.reply.q.get()
if g == exceptions.Kill: if g == exceptions.Kill:
raise exceptions.Kill() raise exceptions.Kill()
@ -34,7 +37,10 @@ class Channel:
then return immediately. then return immediately.
""" """
m.reply = DummyReply() m.reply = DummyReply()
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop) asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
NO_REPLY = object() # special object we can distinguish from a valid "None" reply. NO_REPLY = object() # special object we can distinguish from a valid "None" reply.

View File

@ -1,3 +1,5 @@
import asyncio
class LogEntry: class LogEntry:
def __init__(self, msg, level): def __init__(self, msg, level):
@ -54,7 +56,9 @@ class Log:
self(txt, "error") self(txt, "error")
def __call__(self, text, level="info"): def __call__(self, text, level="info"):
self.master.add_log(text, level) asyncio.get_event_loop().call_soon(
self.master.addons.trigger, "log", LogEntry(text, level)
)
LogTierOrder = [ LogTierOrder = [

View File

@ -1,5 +1,4 @@
import threading import threading
import contextlib
import asyncio import asyncio
import logging import logging
@ -43,15 +42,12 @@ class Master:
The master handles mitmproxy's main event loop. The master handles mitmproxy's main event loop.
""" """
def __init__(self, opts): def __init__(self, opts):
self.event_queue = asyncio.Queue()
self.should_exit = threading.Event() self.should_exit = threading.Event()
self.channel = controller.Channel( self.channel = controller.Channel(
self,
asyncio.get_event_loop(), asyncio.get_event_loop(),
self.event_queue,
self.should_exit, self.should_exit,
) )
asyncio.ensure_future(self.main())
asyncio.ensure_future(self.tick())
self.options = opts or options.Options() # type: options.Options self.options = opts or options.Options() # type: options.Options
self.commands = command.CommandManager(self) self.commands = command.CommandManager(self)
@ -59,6 +55,11 @@ class Master:
self._server = None self._server = None
self.first_tick = True self.first_tick = True
self.waiting_flows = [] self.waiting_flows = []
self.log = log.Log(self)
mitmproxy_ctx.master = self
mitmproxy_ctx.log = self.log
mitmproxy_ctx.options = self.options
@property @property
def server(self): def server(self):
@ -69,49 +70,11 @@ class Master:
server.set_channel(self.channel) server.set_channel(self.channel)
self._server = server self._server = server
@contextlib.contextmanager
def handlecontext(self):
# Handlecontexts also have to nest - leave cleanup to the outermost
if mitmproxy_ctx.master:
yield
return
mitmproxy_ctx.master = self
mitmproxy_ctx.log = log.Log(self)
mitmproxy_ctx.options = self.options
try:
yield
finally:
mitmproxy_ctx.master = None
mitmproxy_ctx.log = None
mitmproxy_ctx.options = None
# This is a vestigial function that will go away in a refactor very soon
def tell(self, mtype, m): # pragma: no cover
m.reply = controller.DummyReply()
self.event_queue.put((mtype, m))
def add_log(self, e, level):
"""
level: debug, alert, info, warn, error
"""
self.addons.trigger("log", log.LogEntry(e, level))
def start(self): def start(self):
self.should_exit.clear() self.should_exit.clear()
if self.server: if self.server:
ServerThread(self.server).start() ServerThread(self.server).start()
async def main(self):
while True:
try:
mtype, obj = await self.event_queue.get()
except RuntimeError:
return
if mtype not in eventsequence.Events: # pragma: no cover
raise exceptions.ControlException("Unknown event %s" % repr(mtype))
self.addons.handle_lifecycle(mtype, obj)
self.event_queue.task_done()
async def tick(self): async def tick(self):
if self.first_tick: if self.first_tick:
self.first_tick = False self.first_tick = False
@ -150,7 +113,7 @@ class Master:
f.request.host, f.request.port = upstream_spec.address f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme f.request.scheme = upstream_spec.scheme
def load_flow(self, f): async def load_flow(self, f):
""" """
Loads a flow and links websocket & handshake flows Loads a flow and links websocket & handshake flows
""" """
@ -168,7 +131,7 @@ class Master:
f.reply = controller.DummyReply() f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f): for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o) await self.addons.handle_lifecycle(e, o)
def replay_request( def replay_request(
self, self,

View File

@ -1,4 +1,5 @@
import contextlib import contextlib
import asyncio
import sys import sys
import mitmproxy.master import mitmproxy.master
@ -34,7 +35,7 @@ class RecordingMaster(mitmproxy.master.Master):
for i in self.logs: for i in self.logs:
print("%s: %s" % (i.level, i.msg), file=outf) print("%s: %s" % (i.level, i.msg), file=outf)
def has_log(self, txt, level=None): def _has_log(self, txt, level=None):
for i in self.logs: for i in self.logs:
if level and i.level != level: if level and i.level != level:
continue continue
@ -42,6 +43,14 @@ class RecordingMaster(mitmproxy.master.Master):
return True return True
return False return False
async def await_log(self, txt, level=None):
for i in range(20):
if self._has_log(txt, level):
return True
else:
await asyncio.sleep(0.1)
return False
def has_event(self, name): def has_event(self, name):
for i in self.events: for i in self.events:
if i[0] == name: if i[0] == name:
@ -65,7 +74,6 @@ class context:
options options
) )
self.options = self.master.options self.options = self.master.options
self.wrapped = None
if loadcore: if loadcore:
self.master.addons.add(core.Core()) self.master.addons.add(core.Core())
@ -73,20 +81,10 @@ class context:
for a in addons: for a in addons:
self.master.addons.add(a) self.master.addons.add(a)
def ctx(self):
"""
Returns a new handler context.
"""
return self.master.handlecontext()
def __enter__(self): def __enter__(self):
self.wrapped = self.ctx()
self.wrapped.__enter__()
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.wrapped.__exit__(exc_type, exc_value, traceback)
self.wrapped = None
return False return False
@contextlib.contextmanager @contextlib.contextmanager

View File

@ -121,7 +121,7 @@ class FlowDetails(tabs.Tabs):
viewmode, message viewmode, message
) )
if error: if error:
self.master.add_log(error, "debug") self.master.log.debug(error)
# Give hint that you have to tab for the response. # Give hint that you have to tab for the response.
if description == "No content" and isinstance(message, http.HTTPRequest): if description == "No content" and isinstance(message, http.HTTPRequest):
description = "No request content (press tab to view response)" description = "No request content (press tab to view response)"

View File

@ -4,6 +4,7 @@ import logging
import os.path import os.path
import re import re
from io import BytesIO from io import BytesIO
import asyncio
import mitmproxy.flow import mitmproxy.flow
import tornado.escape import tornado.escape
@ -235,7 +236,7 @@ class DumpFlows(RequestHandler):
self.view.clear() self.view.clear()
bio = BytesIO(self.filecontents) bio = BytesIO(self.filecontents)
for i in io.FlowReader(bio).stream(): for i in io.FlowReader(bio).stream():
self.master.load_flow(i) asyncio.call_soon(self.master.load_flow, i)
bio.close() bio.close()

View File

@ -114,17 +114,15 @@ class WebMaster(master.Master):
iol.add_callback(self.start) iol.add_callback(self.start)
web_url = "http://{}:{}/".format(self.options.web_iface, self.options.web_port) web_url = "http://{}:{}/".format(self.options.web_iface, self.options.web_port)
self.add_log( self.log.info(
"Web server listening at {}".format(web_url), "Web server listening at {}".format(web_url),
"info"
) )
if self.options.web_open_browser: if self.options.web_open_browser:
success = open_browser(web_url) success = open_browser(web_url)
if not success: if not success:
self.add_log( self.log.info(
"No web browser found. Please open a browser and point it to {}".format(web_url), "No web browser found. Please open a browser and point it to {}".format(web_url),
"info"
) )
try: try:
iol.start() iol.start()

View File

@ -88,6 +88,7 @@ setup(
"flake8>=3.5, <3.6", "flake8>=3.5, <3.6",
"Flask>=0.10.1, <0.13", "Flask>=0.10.1, <0.13",
"mypy>=0.580,<0.581", "mypy>=0.580,<0.581",
"pytest-asyncio>=0.8",
"pytest-cov>=2.5.1,<3", "pytest-cov>=2.5.1,<3",
"pytest-faulthandler>=1.3.1,<2", "pytest-faulthandler>=1.3.1,<2",
"pytest-timeout>=1.2.1,<2", "pytest-timeout>=1.2.1,<2",

View File

@ -17,7 +17,8 @@ from mitmproxy.test import taddons
(False, "fe80::", False), (False, "fe80::", False),
(False, "2001:4860:4860::8888", True), (False, "2001:4860:4860::8888", True),
]) ])
def test_allowremote(allow_remote, ip, should_be_killed): @pytest.mark.asyncio
async def test_allowremote(allow_remote, ip, should_be_killed):
ar = allowremote.AllowRemote() ar = allowremote.AllowRemote()
up = proxyauth.ProxyAuth() up = proxyauth.ProxyAuth()
with taddons.context(ar, up) as tctx: with taddons.context(ar, up) as tctx:
@ -28,7 +29,7 @@ def test_allowremote(allow_remote, ip, should_be_killed):
ar.clientconnect(layer) ar.clientconnect(layer)
if should_be_killed: if should_be_killed:
assert tctx.master.has_log("Client connection was killed", "warn") assert await tctx.master.await_log("Client connection was killed", "warn")
else: else:
assert tctx.master.logs == [] assert tctx.master.logs == []
tctx.master.clear() tctx.master.clear()

View File

@ -1,31 +1,33 @@
from unittest import mock from unittest import mock
import pytest
from mitmproxy.addons import browser from mitmproxy.addons import browser
from mitmproxy.test import taddons from mitmproxy.test import taddons
def test_browser(): @pytest.mark.asyncio
async def test_browser():
with mock.patch("subprocess.Popen") as po, mock.patch("shutil.which") as which: with mock.patch("subprocess.Popen") as po, mock.patch("shutil.which") as which:
which.return_value = "chrome" which.return_value = "chrome"
b = browser.Browser() b = browser.Browser()
with taddons.context() as tctx: with taddons.context() as tctx:
b.start() b.start()
assert po.called assert po.called
b.start()
assert not tctx.master.has_log("already running") b.start()
b.browser.poll = lambda: None b.browser.poll = lambda: None
b.start() b.start()
assert tctx.master.has_log("already running") assert await tctx.master.await_log("already running")
b.done() b.done()
assert not b.browser assert not b.browser
def test_no_browser(): @pytest.mark.asyncio
async def test_no_browser():
with mock.patch("shutil.which") as which: with mock.patch("shutil.which") as which:
which.return_value = False which.return_value = False
b = browser.Browser() b = browser.Browser()
with taddons.context() as tctx: with taddons.context() as tctx:
b.start() b.start()
assert tctx.master.has_log("platform is not supported") assert await tctx.master.await_log("platform is not supported")

View File

@ -8,12 +8,15 @@ from mitmproxy.test import taddons
class TestCheckCA: class TestCheckCA:
@pytest.mark.parametrize('expired', [False, True]) @pytest.mark.parametrize('expired', [False, True])
def test_check_ca(self, expired): @pytest.mark.asyncio
async def test_check_ca(self, expired):
msg = 'The mitmproxy certificate authority has expired!' msg = 'The mitmproxy certificate authority has expired!'
with taddons.context() as tctx: with taddons.context() as tctx:
tctx.master.server = mock.MagicMock() tctx.master.server = mock.MagicMock()
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(return_value=expired) tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(
return_value = expired
)
a = check_ca.CheckCA() a = check_ca.CheckCA()
tctx.configure(a) tctx.configure(a)
assert tctx.master.has_log(msg) is expired assert await tctx.master.await_log(msg) == expired

View File

@ -71,7 +71,8 @@ def qr(f):
return fp.read() return fp.read()
def test_cut_clip(): @pytest.mark.asyncio
async def test_cut_clip():
v = view.View() v = view.View()
c = cut.Cut() c = cut.Cut()
with taddons.context() as tctx: with taddons.context() as tctx:
@ -95,7 +96,7 @@ def test_cut_clip():
"copy/paste mechanism for your system." "copy/paste mechanism for your system."
pc.side_effect = pyperclip.PyperclipException(log_message) pc.side_effect = pyperclip.PyperclipException(log_message)
tctx.command(c.clip, "@all", "request.method") tctx.command(c.clip, "@all", "request.method")
assert tctx.master.has_log(log_message, level="error") assert await tctx.master.await_log(log_message, level="error")
def test_cut_save(tmpdir): def test_cut_save(tmpdir):
@ -125,7 +126,8 @@ def test_cut_save(tmpdir):
(IsADirectoryError, "Is a directory"), (IsADirectoryError, "Is a directory"),
(FileNotFoundError, "No such file or directory") (FileNotFoundError, "No such file or directory")
]) ])
def test_cut_save_open(exception, log_message, tmpdir): @pytest.mark.asyncio
async def test_cut_save_open(exception, log_message, tmpdir):
f = str(tmpdir.join("path")) f = str(tmpdir.join("path"))
v = view.View() v = view.View()
c = cut.Cut() c = cut.Cut()
@ -136,7 +138,7 @@ def test_cut_save_open(exception, log_message, tmpdir):
with mock.patch("mitmproxy.addons.cut.open") as m: with mock.patch("mitmproxy.addons.cut.open") as m:
m.side_effect = exception(log_message) m.side_effect = exception(log_message)
tctx.command(c.save, "@all", "request.method", f) tctx.command(c.save, "@all", "request.method", f)
assert tctx.master.has_log(log_message, level="error") assert await tctx.master.await_log(log_message, level="error")
def test_cut(): def test_cut():

View File

@ -141,15 +141,16 @@ def test_echo_request_line():
class TestContentView: class TestContentView:
@mock.patch("mitmproxy.contentviews.auto.ViewAuto.__call__") @pytest.mark.asyncio
def test_contentview(self, view_auto): async def test_contentview(self):
view_auto.side_effect = exceptions.ContentViewException("") with mock.patch("mitmproxy.contentviews.auto.ViewAuto.__call__") as va:
va.side_effect = exceptions.ContentViewException("")
sio = io.StringIO() sio = io.StringIO()
d = dumper.Dumper(sio) d = dumper.Dumper(sio)
with taddons.context(d) as ctx: with taddons.context(d) as ctx:
ctx.configure(d, flow_detail=4) ctx.configure(d, flow_detail=4)
d.response(tflow.tflow()) d.response(tflow.tflow())
assert ctx.master.has_log("content viewer failed") assert await ctx.master.await_log("content viewer failed")
def test_tcp(): def test_tcp():

View File

@ -125,17 +125,19 @@ def test_export(tmpdir):
(IsADirectoryError, "Is a directory"), (IsADirectoryError, "Is a directory"),
(FileNotFoundError, "No such file or directory") (FileNotFoundError, "No such file or directory")
]) ])
def test_export_open(exception, log_message, tmpdir): @pytest.mark.asyncio
async def test_export_open(exception, log_message, tmpdir):
f = str(tmpdir.join("path")) f = str(tmpdir.join("path"))
e = export.Export() e = export.Export()
with taddons.context() as tctx: with taddons.context() as tctx:
with mock.patch("mitmproxy.addons.export.open") as m: with mock.patch("mitmproxy.addons.export.open") as m:
m.side_effect = exception(log_message) m.side_effect = exception(log_message)
e.file("raw", tflow.tflow(resp=True), f) e.file("raw", tflow.tflow(resp=True), f)
assert tctx.master.has_log(log_message, level="error") assert await tctx.master.await_log(log_message, level="error")
def test_clip(tmpdir): @pytest.mark.asyncio
async def test_clip(tmpdir):
e = export.Export() e = export.Export()
with taddons.context() as tctx: with taddons.context() as tctx:
with pytest.raises(exceptions.CommandError): with pytest.raises(exceptions.CommandError):
@ -158,4 +160,4 @@ def test_clip(tmpdir):
"copy/paste mechanism for your system." "copy/paste mechanism for your system."
pc.side_effect = pyperclip.PyperclipException(log_message) pc.side_effect = pyperclip.PyperclipException(log_message)
e.clip("raw", tflow.tflow(resp=True)) e.clip("raw", tflow.tflow(resp=True))
assert tctx.master.has_log(log_message, level="error") assert await tctx.master.await_log(log_message, level="error")

View File

@ -55,26 +55,28 @@ class TestReadFile:
with pytest.raises(exceptions.OptionsError): with pytest.raises(exceptions.OptionsError):
rf.running() rf.running()
@mock.patch('mitmproxy.master.Master.load_flow') @pytest.mark.asyncio
def test_corrupt(self, mck, corrupt_data): async def test_corrupt(self, corrupt_data):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context(rf) as tctx: with taddons.context(rf) as tctx:
with mock.patch('mitmproxy.master.Master.load_flow') as mck:
with pytest.raises(exceptions.FlowReadException): with pytest.raises(exceptions.FlowReadException):
rf.load_flows(io.BytesIO(b"qibble")) rf.load_flows(io.BytesIO(b"qibble"))
assert not mck.called assert not mck.called
assert len(tctx.master.logs) == 1
tctx.master.clear()
with pytest.raises(exceptions.FlowReadException): with pytest.raises(exceptions.FlowReadException):
rf.load_flows(corrupt_data) rf.load_flows(corrupt_data)
assert await tctx.master.await_log("file corrupted")
assert mck.called assert mck.called
assert len(tctx.master.logs) == 2
def test_nonexisting_file(self): @pytest.mark.asyncio
async def test_nonexisting_file(self):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context(rf) as tctx: with taddons.context(rf) as tctx:
with pytest.raises(exceptions.FlowReadException): with pytest.raises(exceptions.FlowReadException):
rf.load_flows_from_path("nonexistent") rf.load_flows_from_path("nonexistent")
assert len(tctx.master.logs) == 1 assert await tctx.master.await_log("nonexistent")
class TestReadFileStdin: class TestReadFileStdin:

View File

@ -79,7 +79,8 @@ class TestReplaceFile:
r.request(f) r.request(f)
assert f.request.content == b"bar" assert f.request.content == b"bar"
def test_nonexistent(self, tmpdir): @pytest.mark.asyncio
async def test_nonexistent(self, tmpdir):
r = replace.Replace() r = replace.Replace()
with taddons.context(r) as tctx: with taddons.context(r) as tctx:
with pytest.raises(Exception, match="Invalid file path"): with pytest.raises(Exception, match="Invalid file path"):
@ -97,6 +98,5 @@ class TestReplaceFile:
tmpfile.remove() tmpfile.remove()
f = tflow.tflow() f = tflow.tflow()
f.request.content = b"foo" f.request.content = b"foo"
assert not tctx.master.logs
r.request(f) r.request(f)
assert tctx.master.logs assert await tctx.master.await_log("could not read")

View File

@ -1,13 +1,11 @@
import os import os
import sys import sys
import traceback import traceback
from unittest import mock
import pytest import pytest
from mitmproxy import addonmanager from mitmproxy import addonmanager
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import log
from mitmproxy.addons import script from mitmproxy.addons import script
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy.test import tflow from mitmproxy.test import tflow
@ -49,17 +47,15 @@ def test_load_fullname():
assert not hasattr(ns2, "addons") assert not hasattr(ns2, "addons")
def test_script_print_stdout(): @pytest.mark.asyncio
async def test_script_print_stdout():
with taddons.context() as tctx: with taddons.context() as tctx:
with mock.patch('mitmproxy.ctx.master.tell') as mock_warn:
with addonmanager.safecall(): with addonmanager.safecall():
ns = script.load_script( ns = script.load_script(
tutils.test_data.path( tutils.test_data.path("mitmproxy/data/addonscripts/print.py")
"mitmproxy/data/addonscripts/print.py"
)
) )
ns.load(addonmanager.Loader(tctx.master)) ns.load(addonmanager.Loader(tctx.master))
mock_warn.assert_called_once_with("log", log.LogEntry("stdoutprint", "warn")) assert await tctx.master.await_log("stdoutprint")
class TestScript: class TestScript:
@ -101,7 +97,8 @@ class TestScript:
assert rec.call_log[0][1] == "request" assert rec.call_log[0][1] == "request"
def test_reload(self, tmpdir): @pytest.mark.asyncio
async def test_reload(self, tmpdir):
with taddons.context() as tctx: with taddons.context() as tctx:
f = tmpdir.join("foo.py") f = tmpdir.join("foo.py")
f.ensure(file=True) f.ensure(file=True)
@ -109,15 +106,15 @@ class TestScript:
sc = script.Script(str(f)) sc = script.Script(str(f))
tctx.configure(sc) tctx.configure(sc)
sc.tick() sc.tick()
assert tctx.master.has_log("Loading") assert await tctx.master.await_log("Loading")
tctx.master.clear() tctx.master.clear()
assert not tctx.master.has_log("Loading")
sc.last_load, sc.last_mtime = 0, 0 sc.last_load, sc.last_mtime = 0, 0
sc.tick() sc.tick()
assert tctx.master.has_log("Loading") assert await tctx.master.await_log("Loading")
def test_exception(self): @pytest.mark.asyncio
async def test_exception(self):
with taddons.context() as tctx: with taddons.context() as tctx:
sc = script.Script( sc = script.Script(
tutils.test_data.path("mitmproxy/data/addonscripts/error.py") tutils.test_data.path("mitmproxy/data/addonscripts/error.py")
@ -129,8 +126,8 @@ class TestScript:
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
tctx.master.addons.trigger("request", f) tctx.master.addons.trigger("request", f)
assert tctx.master.has_log("ValueError: Error!") assert await tctx.master.await_log("ValueError: Error!")
assert tctx.master.has_log("error.py") assert await tctx.master.await_log("error.py")
def test_addon(self): def test_addon(self):
with taddons.context() as tctx: with taddons.context() as tctx:
@ -166,13 +163,15 @@ class TestCutTraceback:
class TestScriptLoader: class TestScriptLoader:
def test_script_run(self): @pytest.mark.asyncio
async def test_script_run(self):
rp = tutils.test_data.path( rp = tutils.test_data.path(
"mitmproxy/data/addonscripts/recorder/recorder.py" "mitmproxy/data/addonscripts/recorder/recorder.py"
) )
sc = script.ScriptLoader() sc = script.ScriptLoader()
with taddons.context(sc) as tctx: with taddons.context(sc) as tctx:
sc.script_run([tflow.tflow(resp=True)], rp) sc.script_run([tflow.tflow(resp=True)], rp)
await tctx.master.await_log("recorder response")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"] debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [ assert debug == [
'recorder load', 'recorder running', 'recorder configure', 'recorder load', 'recorder running', 'recorder configure',
@ -181,11 +180,12 @@ class TestScriptLoader:
'recorder responseheaders', 'recorder response' 'recorder responseheaders', 'recorder response'
] ]
def test_script_run_nonexistent(self): @pytest.mark.asyncio
async def test_script_run_nonexistent(self):
sc = script.ScriptLoader() sc = script.ScriptLoader()
with taddons.context(sc) as tctx: with taddons.context(sc) as tctx:
sc.script_run([tflow.tflow(resp=True)], "/") sc.script_run([tflow.tflow(resp=True)], "/")
tctx.master.has_log("/: No such script") assert await tctx.master.await_log("/: No such script")
def test_simple(self): def test_simple(self):
sc = script.ScriptLoader() sc = script.ScriptLoader()
@ -243,19 +243,21 @@ class TestScriptLoader:
tctx.invoke(sc, "tick") tctx.invoke(sc, "tick")
assert len(tctx.master.addons) == 1 assert len(tctx.master.addons) == 1
def test_script_error_handler(self): @pytest.mark.asyncio
async def test_script_error_handler(self):
path = "/sample/path/example.py" path = "/sample/path/example.py"
exc = SyntaxError exc = SyntaxError
msg = "Error raised" msg = "Error raised"
tb = True tb = True
with taddons.context() as tctx: with taddons.context() as tctx:
script.script_error_handler(path, exc, msg, tb) script.script_error_handler(path, exc, msg, tb)
assert tctx.master.has_log("/sample/path/example.py") assert await tctx.master.await_log("/sample/path/example.py")
assert tctx.master.has_log("Error raised") assert await tctx.master.await_log("Error raised")
assert tctx.master.has_log("lineno") assert await tctx.master.await_log("lineno")
assert tctx.master.has_log("NoneType") assert await tctx.master.await_log("NoneType")
def test_order(self): @pytest.mark.asyncio
async def test_order(self):
rec = tutils.test_data.path("mitmproxy/data/addonscripts/recorder") rec = tutils.test_data.path("mitmproxy/data/addonscripts/recorder")
sc = script.ScriptLoader() sc = script.ScriptLoader()
sc.is_running = True sc.is_running = True
@ -269,6 +271,7 @@ class TestScriptLoader:
] ]
) )
tctx.master.addons.invoke_addon(sc, "tick") tctx.master.addons.invoke_addon(sc, "tick")
await tctx.master.await_log("c tick")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"] debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [ assert debug == [
'a load', 'a load',
@ -287,7 +290,7 @@ class TestScriptLoader:
'c tick', 'c tick',
] ]
tctx.master.logs = [] tctx.master.clear()
tctx.configure( tctx.configure(
sc, sc,
scripts = [ scripts = [
@ -297,6 +300,7 @@ class TestScriptLoader:
] ]
) )
await tctx.master.await_log("c configure")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"] debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [ assert debug == [
'c configure', 'c configure',
@ -313,6 +317,7 @@ class TestScriptLoader:
] ]
) )
tctx.master.addons.invoke_addon(sc, "tick") tctx.master.addons.invoke_addon(sc, "tick")
await tctx.master.await_log("a tick")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"] debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [ assert debug == [

View File

@ -1,15 +1,17 @@
import pytest
from mitmproxy import proxy from mitmproxy import proxy
from mitmproxy.addons import termstatus from mitmproxy.addons import termstatus
from mitmproxy.test import taddons from mitmproxy.test import taddons
def test_configure(): @pytest.mark.asyncio
async def test_configure():
ts = termstatus.TermStatus() ts = termstatus.TermStatus()
with taddons.context() as ctx: with taddons.context() as ctx:
ctx.master.server = proxy.DummyServer() ctx.master.server = proxy.DummyServer()
ctx.configure(ts, server=False) ctx.configure(ts, server=False)
ts.running() ts.running()
assert not ctx.master.logs
ctx.configure(ts, server=True) ctx.configure(ts, server=True)
ts.running() ts.running()
assert ctx.master.logs await ctx.master.await_log("server listening")

View File

@ -159,7 +159,8 @@ def test_orders():
assert v.order_options() assert v.order_options()
def test_load(tmpdir): @pytest.mark.asyncio
async def test_load(tmpdir):
path = str(tmpdir.join("path")) path = str(tmpdir.join("path"))
v = view.View() v = view.View()
with taddons.context() as tctx: with taddons.context() as tctx:
@ -182,7 +183,7 @@ def test_load(tmpdir):
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(b"invalidflows") f.write(b"invalidflows")
v.load_file(path) v.load_file(path)
assert tctx.master.has_log("Invalid data format.") assert await tctx.master.await_log("Invalid data format.")
def test_resolve(): def test_resolve():

View File

@ -3,7 +3,6 @@ import os
import struct import struct
import tempfile import tempfile
import traceback import traceback
import time
from mitmproxy import options from mitmproxy import options
from mitmproxy import exceptions from mitmproxy import exceptions
@ -48,6 +47,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
class _WebSocketTestBase: class _WebSocketTestBase:
client = None
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
@ -286,7 +286,8 @@ class TestPing(_WebSocketTest):
wfile.flush() wfile.flush()
websockets.Frame.from_file(rfile) websockets.Frame.from_file(rfile)
def test_ping(self): @pytest.mark.asyncio
async def test_ping(self):
self.setup_connection() self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
@ -296,7 +297,7 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PING assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'' # We don't send payload to other end assert frame.payload == b'' # We don't send payload to other end
assert self.master.has_log("Pong Received from server", "info") assert await self.master.await_log("Pong Received from server", "info")
class TestPong(_WebSocketTest): class TestPong(_WebSocketTest):
@ -314,7 +315,8 @@ class TestPong(_WebSocketTest):
wfile.flush() wfile.flush()
websockets.Frame.from_file(rfile) websockets.Frame.from_file(rfile)
def test_pong(self): @pytest.mark.asyncio
async def test_pong(self):
self.setup_connection() self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
@ -327,12 +329,7 @@ class TestPong(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PONG assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar' assert frame.payload == b'foobar'
for i in range(20): assert await self.master.await_log("pong received")
if self.master.has_log("Pong Received from server", "info"):
break
time.sleep(0.01)
else:
raise AssertionError("No pong seen")
class TestClose(_WebSocketTest): class TestClose(_WebSocketTest):

View File

@ -1,3 +1,4 @@
import asyncio
import os import os
import socket import socket
import time import time
@ -123,8 +124,6 @@ class TcpMixin:
i2 = self.pathod("306") i2 = self.pathod("306")
self._ignore_off() self._ignore_off()
self.master.event_queue.join()
assert n.status_code == 304 assert n.status_code == 304
assert i.status_code == 305 assert i.status_code == 305
assert i2.status_code == 306 assert i2.status_code == 306
@ -168,8 +167,6 @@ class TcpMixin:
i2 = self.pathod("306") i2 = self.pathod("306")
self._tcpproxy_off() self._tcpproxy_off()
self.master.event_queue.join()
assert n.status_code == 304 assert n.status_code == 304
assert i.status_code == 305 assert i.status_code == 305
assert i2.status_code == 306 assert i2.status_code == 306
@ -238,13 +235,14 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin):
assert p.request(req) assert p.request(req)
assert p.request(req) assert p.request(req)
def test_get_connection_switching(self): @pytest.mark.asyncio
async def test_get_connection_switching(self):
req = "get:'%s/p/200:b@1'" req = "get:'%s/p/200:b@1'"
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():
assert p.request(req % self.server.urlbase) assert p.request(req % self.server.urlbase)
assert p.request(req % self.server2.urlbase) assert p.request(req % self.server2.urlbase)
assert self.proxy.tmaster.has_log("serverdisconnect") assert await self.proxy.tmaster.await_log("serverdisconnect")
def test_blank_leading_line(self): def test_blank_leading_line(self):
p = self.pathoc() p = self.pathoc()
@ -447,13 +445,14 @@ class TestReverse(tservers.ReverseProxyTest, CommonMixin, TcpMixin):
req = self.master.state.flows[0].request req = self.master.state.flows[0].request
assert req.host_header == "127.0.0.1" assert req.host_header == "127.0.0.1"
def test_selfconnection(self): @pytest.mark.asyncio
async def test_selfconnection(self):
self.options.mode = "reverse:http://127.0.0.1:0" self.options.mode = "reverse:http://127.0.0.1:0"
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():
p.request("get:/") p.request("get:/")
assert self.master.has_log("The proxy shall not connect to itself.") assert await self.master.await_log("The proxy shall not connect to itself.")
class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin): class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin):
@ -553,7 +552,6 @@ class TestHttps2Http(tservers.ReverseProxyTest):
p = self.pathoc(ssl=True, sni="example.com") p = self.pathoc(ssl=True, sni="example.com")
with p.connect(): with p.connect():
assert p.request("get:'/p/200'").status_code == 200 assert p.request("get:'/p/200'").status_code == 200
assert not self.proxy.tmaster.has_log("error in handle_sni")
def test_http(self): def test_http(self):
p = self.pathoc(ssl=False) p = self.pathoc(ssl=False)
@ -818,11 +816,13 @@ class TestServerConnect(tservers.HTTPProxyTest):
opts.upstream_cert = False opts.upstream_cert = False
return opts return opts
def test_unnecessary_serverconnect(self): @pytest.mark.asyncio
async def test_unnecessary_serverconnect(self):
"""A replayed/fake response with no upstream_cert should not connect to an upstream server""" """A replayed/fake response with no upstream_cert should not connect to an upstream server"""
self.set_addons(AFakeResponse()) self.set_addons(AFakeResponse())
assert self.pathod("200").status_code == 200 assert self.pathod("200").status_code == 200
assert not self.proxy.tmaster.has_log("serverconnect") asyncio.sleep(0.1)
assert not self.proxy.tmaster._has_log("serverconnect")
class AKillRequest: class AKillRequest:

View File

@ -1,3 +1,5 @@
import pytest
from mitmproxy.test import tflow from mitmproxy.test import tflow
from mitmproxy.test import tutils from mitmproxy.test import tutils
from mitmproxy.test import taddons from mitmproxy.test import taddons
@ -31,14 +33,15 @@ class TestConcurrent(tservers.MasterTest):
return return
raise ValueError("Script never acked") raise ValueError("Script never acked")
def test_concurrent_err(self): @pytest.mark.asyncio
async def test_concurrent_err(self):
with taddons.context() as tctx: with taddons.context() as tctx:
tctx.script( tctx.script(
tutils.test_data.path( tutils.test_data.path(
"mitmproxy/data/addonscripts/concurrent_decorator_err.py" "mitmproxy/data/addonscripts/concurrent_decorator_err.py"
) )
) )
assert tctx.master.has_log("decorator not supported") assert await tctx.master.await_log("decorator not supported")
def test_concurrent_class(self): def test_concurrent_class(self):
with taddons.context() as tctx: with taddons.context() as tctx:

View File

@ -1,4 +1,6 @@
import pytest import pytest
from unittest import mock
from mitmproxy import addons from mitmproxy import addons
from mitmproxy import addonmanager from mitmproxy import addonmanager
@ -65,7 +67,8 @@ def test_halt():
assert end.custom_called assert end.custom_called
def test_lifecycle(): @pytest.mark.asyncio
async def test_lifecycle():
o = options.Options() o = options.Options()
m = master.Master(o) m = master.Master(o)
a = addonmanager.AddonManager(m) a = addonmanager.AddonManager(m)
@ -77,7 +80,7 @@ def test_lifecycle():
a.remove(TAddon("nonexistent")) a.remove(TAddon("nonexistent"))
f = tflow.tflow() f = tflow.tflow()
a.handle_lifecycle("request", f) await a.handle_lifecycle("request", f)
a._configure_all(o, o.keys()) a._configure_all(o, o.keys())
@ -86,19 +89,21 @@ def test_defaults():
assert addons.default_addons() assert addons.default_addons()
def test_loader(): @pytest.mark.asyncio
async def test_loader():
with taddons.context() as tctx: with taddons.context() as tctx:
with mock.patch("mitmproxy.ctx.log.warn") as warn:
l = addonmanager.Loader(tctx.master) l = addonmanager.Loader(tctx.master)
l.add_option("custom_option", bool, False, "help") l.add_option("custom_option", bool, False, "help")
assert "custom_option" in l.master.options assert "custom_option" in l.master.options
# calling this again with the same signature is a no-op. # calling this again with the same signature is a no-op.
l.add_option("custom_option", bool, False, "help") l.add_option("custom_option", bool, False, "help")
assert not tctx.master.has_log("Over-riding existing option") assert not warn.called
# a different signature should emit a warning though. # a different signature should emit a warning though.
l.add_option("custom_option", bool, True, "help") l.add_option("custom_option", bool, True, "help")
assert tctx.master.has_log("Over-riding existing option") assert warn.called
def cmd(a: str) -> str: def cmd(a: str) -> str:
return "foo" return "foo"
@ -106,7 +111,8 @@ def test_loader():
l.add_command("test.command", cmd) l.add_command("test.command", cmd)
def test_simple(): @pytest.mark.asyncio
async def test_simple():
with taddons.context(loadcore=False) as tctx: with taddons.context(loadcore=False) as tctx:
a = tctx.master.addons a = tctx.master.addons
@ -120,14 +126,14 @@ def test_simple():
assert not a.chain assert not a.chain
a.add(TAddon("one")) a.add(TAddon("one"))
a.trigger("done") a.trigger("running")
a.trigger("tick") a.trigger("tick")
assert tctx.master.has_log("not callable") assert await tctx.master.await_log("not callable")
tctx.master.clear() tctx.master.clear()
a.get("one").tick = addons a.get("one").tick = addons
a.trigger("tick") a.trigger("tick")
assert not tctx.master.has_log("not callable") assert not await tctx.master.await_log("not callable")
a.remove(a.get("one")) a.remove(a.get("one"))
assert not a.get("one") assert not a.get("one")

View File

@ -5,12 +5,11 @@ import pytest
from mitmproxy.exceptions import Kill, ControlException from mitmproxy.exceptions import Kill, ControlException
from mitmproxy import controller from mitmproxy import controller
from mitmproxy.test import taddons from mitmproxy.test import taddons
import mitmproxy.ctx
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_master(): async def test_master():
class TMsg:
pass
class tAddon: class tAddon:
def log(self, _): def log(self, _):
@ -20,12 +19,11 @@ async def test_master():
assert not ctx.master.should_exit.is_set() assert not ctx.master.should_exit.is_set()
async def test(): async def test():
msg = TMsg() mitmproxy.ctx.log("test")
msg.reply = controller.DummyReply()
await ctx.master.channel.tell("log", msg)
asyncio.ensure_future(test()) asyncio.ensure_future(test())
assert not ctx.master.should_exit.is_set() assert await ctx.master.await_log("test")
assert ctx.master.should_exit.is_set()
class TestReply: class TestReply:

View File

@ -2,7 +2,7 @@ import io
from unittest import mock from unittest import mock
import pytest import pytest
from mitmproxy.test import tflow, tutils from mitmproxy.test import tflow, tutils, taddons
import mitmproxy.io import mitmproxy.io
from mitmproxy import flowfilter from mitmproxy import flowfilter
from mitmproxy import options from mitmproxy import options
@ -97,27 +97,27 @@ class TestSerialize:
class TestFlowMaster: class TestFlowMaster:
def test_load_http_flow_reverse(self): @pytest.mark.asyncio
s = tservers.TestState() async def test_load_http_flow_reverse(self):
opts = options.Options( opts = options.Options(
mode="reverse:https://use-this-domain" mode="reverse:https://use-this-domain"
) )
fm = master.Master(opts) s = tservers.TestState()
fm.addons.add(s) with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
fm.load_flow(f) await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain" assert s.flows[0].request.host == "use-this-domain"
def test_load_websocket_flow(self): @pytest.mark.asyncio
s = tservers.TestState() async def test_load_websocket_flow(self):
opts = options.Options( opts = options.Options(
mode="reverse:https://use-this-domain" mode="reverse:https://use-this-domain"
) )
fm = master.Master(opts) s = tservers.TestState()
fm.addons.add(s) with taddons.context(s, options=opts) as ctx:
f = tflow.twebsocketflow() f = tflow.twebsocketflow()
fm.load_flow(f.handshake_flow) await ctx.master.load_flow(f.handshake_flow)
fm.load_flow(f) await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain" assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages) assert len(s.flows[1].messages) == len(f.messages)
@ -150,31 +150,27 @@ class TestFlowMaster:
assert rt.f.request.http_version == "HTTP/1.1" assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers assert ":authority" not in rt.f.request.headers
def test_all(self): @pytest.mark.asyncio
async def test_all(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
s = tservers.TestState() s = tservers.TestState()
fm = master.Master(None) with taddons.context(s, options=opts) as ctx:
fm.addons.add(s)
f = tflow.tflow(req=None) f = tflow.tflow(req=None)
fm.addons.handle_lifecycle("clientconnect", f.client_conn) await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
fm.addons.handle_lifecycle("request", f) await ctx.master.addons.handle_lifecycle("request", f)
assert len(s.flows) == 1 assert len(s.flows) == 1
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
fm.addons.handle_lifecycle("response", f) await ctx.master.addons.handle_lifecycle("response", f)
assert len(s.flows) == 1 assert len(s.flows) == 1
fm.addons.handle_lifecycle("clientdisconnect", f.client_conn) await ctx.master.addons.handle_lifecycle("clientdisconnect", f.client_conn)
f.error = flow.Error("msg") f.error = flow.Error("msg")
fm.addons.handle_lifecycle("error", f) await ctx.master.addons.handle_lifecycle("error", f)
# FIXME: This no longer works, because we consume on the main loop.
# fm.tell("foo", f)
# with pytest.raises(ControlException):
# fm.addons.trigger("unknown")
fm.shutdown()
class TestError: class TestError:

View File

@ -1,21 +1,27 @@
import io import io
import pytest
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy.test import tutils from mitmproxy.test import tutils
from mitmproxy import ctx from mitmproxy import ctx
def test_recordingmaster(): @pytest.mark.asyncio
async def test_recordingmaster():
with taddons.context() as tctx: with taddons.context() as tctx:
assert not tctx.master.has_log("nonexistent") assert not tctx.master._has_log("nonexistent")
assert not tctx.master.has_event("nonexistent") assert not tctx.master.has_event("nonexistent")
ctx.log.error("foo") ctx.log.error("foo")
assert not tctx.master.has_log("foo", level="debug") assert not tctx.master._has_log("foo", level="debug")
assert tctx.master.has_log("foo", level="error") assert await tctx.master.await_log("foo", level="error")
def test_dumplog(): @pytest.mark.asyncio
async def test_dumplog():
with taddons.context() as tctx: with taddons.context() as tctx:
ctx.log.info("testing") ctx.log.info("testing")
await ctx.master.await_log("testing")
s = io.StringIO() s = io.StringIO()
tctx.master.dump_log(s) tctx.master.dump_log(s)
assert s.getvalue() assert s.getvalue()

View File

@ -4,13 +4,16 @@ from mitmproxy.tools.console import keymap
from mitmproxy.tools.console import master from mitmproxy.tools.console import master
from mitmproxy import command from mitmproxy import command
import pytest
def test_commands_exist():
@pytest.mark.asyncio
async def test_commands_exist():
km = keymap.Keymap(None) km = keymap.Keymap(None)
defaultkeys.map(km) defaultkeys.map(km)
assert km.bindings assert km.bindings
m = master.ConsoleMaster(None) m = master.ConsoleMaster(None)
m.load_flow(tflow()) await m.load_flow(tflow())
for binding in km.bindings: for binding in km.bindings:
cmd, *args = command.lexer(binding.command) cmd, *args = command.lexer(binding.command)

View File

@ -4,7 +4,10 @@ from mitmproxy import options
from mitmproxy.tools import console from mitmproxy.tools import console
from ... import tservers from ... import tservers
import pytest
@pytest.mark.asyncio
class TestMaster(tservers.MasterTest): class TestMaster(tservers.MasterTest):
def mkmaster(self, **opts): def mkmaster(self, **opts):
o = options.Options(**opts) o = options.Options(**opts)
@ -12,11 +15,11 @@ class TestMaster(tservers.MasterTest):
m.addons.trigger("configure", o.keys()) m.addons.trigger("configure", o.keys())
return m return m
def test_basic(self): async def test_basic(self):
m = self.mkmaster() m = self.mkmaster()
for i in (1, 2, 3): for i in (1, 2, 3):
try: try:
self.dummy_cycle(m, 1, b"") await self.dummy_cycle(m, 1, b"")
except urwid.ExitMainLoop: except urwid.ExitMainLoop:
pass pass
assert len(m.view) == i assert len(m.view) == i

View File

@ -2,6 +2,7 @@ import json as _json
import logging import logging
from unittest import mock from unittest import mock
import os import os
import asyncio
import pytest import pytest
import tornado.testing import tornado.testing
@ -32,6 +33,11 @@ def json(resp: httpclient.HTTPResponse):
@pytest.mark.usefixtures("no_tornado_logging") @pytest.mark.usefixtures("no_tornado_logging")
class TestApp(tornado.testing.AsyncHTTPTestCase): class TestApp(tornado.testing.AsyncHTTPTestCase):
def get_new_ioloop(self):
io_loop = tornado.platform.asyncio.AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop
def get_app(self): def get_app(self):
o = options.Options(http2=False) o = options.Options(http2=False)
m = webmaster.WebMaster(o, with_termlog=False) m = webmaster.WebMaster(o, with_termlog=False)
@ -39,7 +45,7 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
f.id = "42" f.id = "42"
m.view.add([f]) m.view.add([f])
m.view.add([tflow.tflow(err=True)]) m.view.add([tflow.tflow(err=True)])
m.add_log("test log", "info") m.log.info("test log")
self.master = m self.master = m
self.view = m.view self.view = m.view
self.events = m.events self.events = m.events
@ -75,12 +81,6 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
resp = self.fetch("/flows/dump") resp = self.fetch("/flows/dump")
assert b"address" in resp.body assert b"address" in resp.body
self.view.clear()
assert not len(self.view)
assert self.fetch("/flows/dump", method="POST", body=resp.body).code == 200
assert len(self.view)
def test_clear(self): def test_clear(self):
events = self.events.data.copy() events = self.events.data.copy()
flows = list(self.view) flows = list(self.view)

View File

@ -1,6 +1,8 @@
from mitmproxy.tools.web import master from mitmproxy.tools.web import master
from mitmproxy import options from mitmproxy import options
import pytest
from ... import tservers from ... import tservers
@ -9,8 +11,9 @@ class TestWebMaster(tservers.MasterTest):
o = options.Options(**opts) o = options.Options(**opts)
return master.WebMaster(o) return master.WebMaster(o)
def test_basic(self): @pytest.mark.asyncio
async def test_basic(self):
m = self.mkmaster() m = self.mkmaster()
for i in (1, 2, 3): for i in (1, 2, 3):
self.dummy_cycle(m, 1, b"") await self.dummy_cycle(m, 1, b"")
assert len(m.view) == i assert len(m.view) == i

View File

@ -26,20 +26,20 @@ from mitmproxy.test import taddons
class MasterTest: class MasterTest:
def cycle(self, master, content): async def cycle(self, master, content):
f = tflow.tflow(req=tutils.treq(content=content)) f = tflow.tflow(req=tutils.treq(content=content))
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer") layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
layer.client_conn = f.client_conn layer.client_conn = f.client_conn
layer.reply = controller.DummyReply() layer.reply = controller.DummyReply()
master.addons.handle_lifecycle("clientconnect", layer) await master.addons.handle_lifecycle("clientconnect", layer)
for i in eventsequence.iterate(f): for i in eventsequence.iterate(f):
master.addons.handle_lifecycle(*i) await master.addons.handle_lifecycle(*i)
master.addons.handle_lifecycle("clientdisconnect", layer) await master.addons.handle_lifecycle("clientdisconnect", layer)
return f return f
def dummy_cycle(self, master, n, content): async def dummy_cycle(self, master, n, content):
for i in range(n): for i in range(n):
self.cycle(master, content) await self.cycle(master, content)
master.shutdown() master.shutdown()
def flowfile(self, path): def flowfile(self, path):