Use async for tctx.cycle/tctx.invoke.

This commit is contained in:
Robert Xiao 2022-02-03 03:25:11 -08:00 committed by Maximilian Hils
parent e186ccb3ba
commit caf49300c2
5 changed files with 48 additions and 26 deletions

View File

@ -1,4 +1,3 @@
import contextlib
import asyncio import asyncio
import sys import sys
@ -79,15 +78,14 @@ class context:
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
return False return False
@contextlib.contextmanager async def cycle(self, addon, f):
def cycle(self, addon, f):
""" """
Cycles the flow through the events for the flow. Stops if a reply Cycles the flow through the events for the flow. Stops if a reply
is taken (as in flow interception). is taken (as in flow interception).
""" """
f.reply._state = "start" f.reply._state = "start"
for evt in eventsequence.iterate(f): for evt in eventsequence.iterate(f):
self.master.addons.invoke_addon_sync( await self.master.addons.invoke_addon(
addon, addon,
evt evt
) )
@ -115,11 +113,11 @@ class context:
sc = script.Script(path, False) sc = script.Script(path, False)
return sc.addons[0] if sc.addons else None return sc.addons[0] if sc.addons else None
def invoke(self, addon, event: hooks.Hook): async def invoke(self, addon, event: hooks.Hook):
""" """
Recursively invoke an event on an addon and all its children. Recursively invoke an event on an addon and all its children.
""" """
return self.master.addons.invoke_addon_sync(addon, event) return await self.master.addons.invoke_addon(addon, event)
def command(self, func, *args): def command(self, func, *args):
""" """

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio
import click import click
from mitmproxy.addons import dumper from mitmproxy.addons import dumper
@ -6,12 +7,24 @@ from mitmproxy.test import tflow
from mitmproxy.test import taddons from mitmproxy.test import taddons
def run_async(coro):
"""
Run the given async function in a new event loop.
This allows async functions to be called synchronously.
"""
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
def show(flow_detail, flows): def show(flow_detail, flows):
d = dumper.Dumper() d = dumper.Dumper()
with taddons.context() as ctx: with taddons.context() as ctx:
ctx.configure(d, flow_detail=flow_detail) ctx.configure(d, flow_detail=flow_detail)
for f in flows: for f in flows:
ctx.cycle(d, f) run_async(ctx.cycle(d, f))
@click.group() @click.group()

View File

@ -7,7 +7,8 @@ from mitmproxy.test import taddons
from mitmproxy.test import tflow from mitmproxy.test import tflow
def test_simple(): @pytest.mark.asyncio
async def test_simple():
r = intercept.Intercept() r = intercept.Intercept()
with taddons.context(r) as tctx: with taddons.context(r) as tctx:
assert not r.filt assert not r.filt
@ -23,11 +24,11 @@ def test_simple():
tctx.configure(r, intercept="~s") tctx.configure(r, intercept="~s")
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
tctx.cycle(r, f) await tctx.cycle(r, f)
assert f.intercepted assert f.intercepted
f = tflow.tflow(resp=False) f = tflow.tflow(resp=False)
tctx.cycle(r, f) await tctx.cycle(r, f)
assert not f.intercepted assert not f.intercepted
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
@ -36,39 +37,41 @@ def test_simple():
tctx.configure(r, intercept_active=False) tctx.configure(r, intercept_active=False)
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
tctx.cycle(r, f) await tctx.cycle(r, f)
assert not f.intercepted assert not f.intercepted
tctx.configure(r, intercept_active=True) tctx.configure(r, intercept_active=True)
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
tctx.cycle(r, f) await tctx.cycle(r, f)
assert f.intercepted assert f.intercepted
def test_tcp(): @pytest.mark.asyncio
async def test_tcp():
r = intercept.Intercept() r = intercept.Intercept()
with taddons.context(r) as tctx: with taddons.context(r) as tctx:
tctx.configure(r, intercept="~tcp") tctx.configure(r, intercept="~tcp")
f = tflow.ttcpflow() f = tflow.ttcpflow()
tctx.cycle(r, f) await tctx.cycle(r, f)
assert f.intercepted assert f.intercepted
tctx.configure(r, intercept_active=False) tctx.configure(r, intercept_active=False)
f = tflow.ttcpflow() f = tflow.ttcpflow()
tctx.cycle(r, f) await tctx.cycle(r, f)
assert not f.intercepted assert not f.intercepted
def test_already_taken(): @pytest.mark.asyncio
async def test_already_taken():
r = intercept.Intercept() r = intercept.Intercept()
with taddons.context(r) as tctx: with taddons.context(r) as tctx:
tctx.configure(r, intercept="~q") tctx.configure(r, intercept="~q")
f = tflow.tflow() f = tflow.tflow()
tctx.invoke(r, layers.http.HttpRequestHook(f)) await tctx.invoke(r, layers.http.HttpRequestHook(f))
assert f.intercepted assert f.intercepted
f = tflow.tflow() f = tflow.tflow()
f.reply.take() f.reply.take()
tctx.invoke(r, layers.http.HttpRequestHook(f)) await tctx.invoke(r, layers.http.HttpRequestHook(f))
assert not f.intercepted assert not f.intercepted

View File

@ -340,7 +340,8 @@ def test_server_playback_full():
assert not tf.response assert not tf.response
def test_server_playback_kill(): @pytest.mark.asyncio
async def test_server_playback_kill():
s = serverplayback.ServerPlayback() s = serverplayback.ServerPlayback()
with taddons.context(s) as tctx: with taddons.context(s) as tctx:
tctx.configure( tctx.configure(
@ -355,7 +356,7 @@ def test_server_playback_kill():
f = tflow.tflow() f = tflow.tflow()
f.request.host = "nonexistent" f.request.host = "nonexistent"
tctx.cycle(s, f) await tctx.cycle(s, f)
assert f.error assert f.error

View File

@ -1,3 +1,4 @@
import asyncio
import time import time
import pytest import pytest
@ -15,7 +16,8 @@ class Thing:
class TestConcurrent: class TestConcurrent:
def test_concurrent(self, tdata): @pytest.mark.asyncio
async def test_concurrent(self, tdata):
with taddons.context() as tctx: with taddons.context() as tctx:
sc = tctx.script( sc = tctx.script(
tdata.path( tdata.path(
@ -23,8 +25,10 @@ class TestConcurrent:
) )
) )
f1, f2 = tflow.tflow(), tflow.tflow() f1, f2 = tflow.tflow(), tflow.tflow()
tctx.cycle(sc, f1) await asyncio.gather(
tctx.cycle(sc, f2) tctx.cycle(sc, f1),
tctx.cycle(sc, f2),
)
start = time.time() start = time.time()
while time.time() - start < 5: while time.time() - start < 5:
if f1.reply.state == f2.reply.state == "committed": if f1.reply.state == f2.reply.state == "committed":
@ -41,7 +45,8 @@ class TestConcurrent:
) )
await tctx.master.await_log("decorator not supported") await tctx.master.await_log("decorator not supported")
def test_concurrent_class(self, tdata): @pytest.mark.asyncio
async def test_concurrent_class(self, tdata):
with taddons.context() as tctx: with taddons.context() as tctx:
sc = tctx.script( sc = tctx.script(
tdata.path( tdata.path(
@ -49,8 +54,10 @@ class TestConcurrent:
) )
) )
f1, f2 = tflow.tflow(), tflow.tflow() f1, f2 = tflow.tflow(), tflow.tflow()
tctx.cycle(sc, f1) await asyncio.gather(
tctx.cycle(sc, f2) tctx.cycle(sc, f1),
tctx.cycle(sc, f2),
)
start = time.time() start = time.time()
while time.time() - start < 5: while time.time() - start < 5:
if f1.reply.state == f2.reply.state == "committed": if f1.reply.state == f2.reply.state == "committed":