asyncio simplify: we don't need a queue for proxy->main loop comms

Instead, we just schedule coroutines directly onto the core loop.
This commit is contained in:
Aldo Cortesi 2018-04-02 20:50:10 +12:00 committed by Aldo Cortesi
parent cdbe6f97af
commit 0fa1280daa
13 changed files with 83 additions and 85 deletions

1
.gitignore vendored
View File

@ -14,6 +14,7 @@ build/
dist/
mitmproxy/contrib/kaitaistruct/*.ksy
.pytest_cache
__pycache__
# UI

View File

@ -3,6 +3,7 @@ import typing
import traceback
import contextlib
import sys
import asyncio
from mitmproxy import exceptions
from mitmproxy import eventsequence
@ -220,7 +221,7 @@ class AddonManager:
name = _get_name(item)
return name in self.lookup
def handle_lifecycle(self, name, message):
async def handle_lifecycle(self, name, message):
"""
Handle a lifecycle event.
"""

View File

@ -8,10 +8,10 @@ class Channel:
The only way for the proxy server to communicate with the master
is to use the channel it has been given.
"""
def __init__(self, loop, q, should_exit):
def __init__(self, master, loop, should_exit):
self.master = master
self.loop = loop
self.should_exit = should_exit
self._q = q
def ask(self, mtype, m):
"""
@ -22,7 +22,10 @@ class Channel:
exceptions.Kill: All connections should be closed immediately.
"""
m.reply = Reply(m)
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
g = m.reply.q.get()
if g == exceptions.Kill:
raise exceptions.Kill()
@ -34,7 +37,10 @@ class Channel:
then return immediately.
"""
m.reply = DummyReply()
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
NO_REPLY = object() # special object we can distinguish from a valid "None" reply.

View File

@ -43,14 +43,12 @@ class Master:
The master handles mitmproxy's main event loop.
"""
def __init__(self, opts):
self.event_queue = asyncio.Queue()
self.should_exit = threading.Event()
self.channel = controller.Channel(
self,
asyncio.get_event_loop(),
self.event_queue,
self.should_exit,
)
asyncio.ensure_future(self.main())
asyncio.ensure_future(self.tick())
self.options = opts or options.Options() # type: options.Options
@ -96,17 +94,6 @@ class Master:
if self.server:
ServerThread(self.server).start()
async def main(self):
while True:
try:
mtype, obj = await self.event_queue.get()
except RuntimeError:
return
if mtype not in eventsequence.Events: # pragma: no cover
raise exceptions.ControlException("Unknown event %s" % repr(mtype))
self.addons.handle_lifecycle(mtype, obj)
self.event_queue.task_done()
async def tick(self):
if self.first_tick:
self.first_tick = False
@ -145,7 +132,7 @@ class Master:
f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme
def load_flow(self, f):
async def load_flow(self, f):
"""
Loads a flow and links websocket & handshake flows
"""
@ -163,7 +150,7 @@ class Master:
f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o)
await self.addons.handle_lifecycle(e, o)
def replay_request(
self,

View File

@ -4,6 +4,7 @@ import logging
import os.path
import re
from io import BytesIO
import asyncio
import mitmproxy.flow
import tornado.escape
@ -235,7 +236,7 @@ class DumpFlows(RequestHandler):
self.view.clear()
bio = BytesIO(self.filecontents)
for i in io.FlowReader(bio).stream():
self.master.load_flow(i)
asyncio.call_soon(self.master.load_flow, i)
bio.close()

View File

@ -123,8 +123,6 @@ class TcpMixin:
i2 = self.pathod("306")
self._ignore_off()
self.master.event_queue.join()
assert n.status_code == 304
assert i.status_code == 305
assert i2.status_code == 306
@ -168,8 +166,6 @@ class TcpMixin:
i2 = self.pathod("306")
self._tcpproxy_off()
self.master.event_queue.join()
assert n.status_code == 304
assert i.status_code == 305
assert i2.status_code == 306

View File

@ -65,7 +65,8 @@ def test_halt():
assert end.custom_called
def test_lifecycle():
@pytest.mark.asyncio
async def test_lifecycle():
o = options.Options()
m = master.Master(o)
a = addonmanager.AddonManager(m)
@ -77,7 +78,7 @@ def test_lifecycle():
a.remove(TAddon("nonexistent"))
f = tflow.tflow()
a.handle_lifecycle("request", f)
await a.handle_lifecycle("request", f)
a._configure_all(o, o.keys())

View File

@ -2,7 +2,7 @@ import io
from unittest import mock
import pytest
from mitmproxy.test import tflow, tutils
from mitmproxy.test import tflow, tutils, taddons
import mitmproxy.io
from mitmproxy import flowfilter
from mitmproxy import options
@ -97,27 +97,27 @@ class TestSerialize:
class TestFlowMaster:
def test_load_http_flow_reverse(self):
s = tservers.TestState()
@pytest.mark.asyncio
async def test_load_http_flow_reverse(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
s = tservers.TestState()
with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(resp=True)
fm.load_flow(f)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
def test_load_websocket_flow(self):
s = tservers.TestState()
@pytest.mark.asyncio
async def test_load_websocket_flow(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
s = tservers.TestState()
with taddons.context(s, options=opts) as ctx:
f = tflow.twebsocketflow()
fm.load_flow(f.handshake_flow)
fm.load_flow(f)
await ctx.master.load_flow(f.handshake_flow)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
@ -150,31 +150,27 @@ class TestFlowMaster:
assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers
def test_all(self):
@pytest.mark.asyncio
async def test_all(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
s = tservers.TestState()
fm = master.Master(None)
fm.addons.add(s)
with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(req=None)
fm.addons.handle_lifecycle("clientconnect", f.client_conn)
await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
fm.addons.handle_lifecycle("request", f)
await ctx.master.addons.handle_lifecycle("request", f)
assert len(s.flows) == 1
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
fm.addons.handle_lifecycle("response", f)
await ctx.master.addons.handle_lifecycle("response", f)
assert len(s.flows) == 1
fm.addons.handle_lifecycle("clientdisconnect", f.client_conn)
await ctx.master.addons.handle_lifecycle("clientdisconnect", f.client_conn)
f.error = flow.Error("msg")
fm.addons.handle_lifecycle("error", f)
# FIXME: This no longer works, because we consume on the main loop.
# fm.tell("foo", f)
# with pytest.raises(ControlException):
# fm.addons.trigger("unknown")
fm.shutdown()
await ctx.master.addons.handle_lifecycle("error", f)
class TestError:

View File

@ -4,13 +4,15 @@ from mitmproxy.tools.console import keymap
from mitmproxy.tools.console import master
from mitmproxy import command
import pytest
def test_commands_exist():
@pytest.mark.asyncio
async def test_commands_exist():
km = keymap.Keymap(None)
defaultkeys.map(km)
assert km.bindings
m = master.ConsoleMaster(None)
m.load_flow(tflow())
await m.load_flow(tflow())
for binding in km.bindings:
cmd, *args = command.lexer(binding.command)

View File

@ -4,6 +4,10 @@ from mitmproxy import options
from mitmproxy.tools import console
from ... import tservers
import pytest
@pytest.mark.asyncio
class TestMaster(tservers.MasterTest):
def mkmaster(self, **opts):
@ -12,11 +16,11 @@ class TestMaster(tservers.MasterTest):
m.addons.trigger("configure", o.keys())
return m
def test_basic(self):
async def test_basic(self):
m = self.mkmaster()
for i in (1, 2, 3):
try:
self.dummy_cycle(m, 1, b"")
await self.dummy_cycle(m, 1, b"")
except urwid.ExitMainLoop:
pass
assert len(m.view) == i

View File

@ -2,6 +2,7 @@ import json as _json
import logging
from unittest import mock
import os
import asyncio
import pytest
import tornado.testing
@ -32,6 +33,11 @@ def json(resp: httpclient.HTTPResponse):
@pytest.mark.usefixtures("no_tornado_logging")
class TestApp(tornado.testing.AsyncHTTPTestCase):
def get_new_ioloop(self):
io_loop = tornado.platform.asyncio.AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop
def get_app(self):
o = options.Options(http2=False)
m = webmaster.WebMaster(o, with_termlog=False)
@ -75,12 +81,6 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
resp = self.fetch("/flows/dump")
assert b"address" in resp.body
self.view.clear()
assert not len(self.view)
assert self.fetch("/flows/dump", method="POST", body=resp.body).code == 200
assert len(self.view)
def test_clear(self):
events = self.events.data.copy()
flows = list(self.view)

View File

@ -1,6 +1,8 @@
from mitmproxy.tools.web import master
from mitmproxy import options
import pytest
from ... import tservers
@ -9,8 +11,9 @@ class TestWebMaster(tservers.MasterTest):
o = options.Options(**opts)
return master.WebMaster(o)
def test_basic(self):
@pytest.mark.asyncio
async def test_basic(self):
m = self.mkmaster()
for i in (1, 2, 3):
self.dummy_cycle(m, 1, b"")
await self.dummy_cycle(m, 1, b"")
assert len(m.view) == i

View File

@ -26,20 +26,20 @@ from mitmproxy.test import taddons
class MasterTest:
def cycle(self, master, content):
async def cycle(self, master, content):
f = tflow.tflow(req=tutils.treq(content=content))
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
layer.client_conn = f.client_conn
layer.reply = controller.DummyReply()
master.addons.handle_lifecycle("clientconnect", layer)
await master.addons.handle_lifecycle("clientconnect", layer)
for i in eventsequence.iterate(f):
master.addons.handle_lifecycle(*i)
master.addons.handle_lifecycle("clientdisconnect", layer)
await master.addons.handle_lifecycle(*i)
await master.addons.handle_lifecycle("clientdisconnect", layer)
return f
def dummy_cycle(self, master, n, content):
async def dummy_cycle(self, master, n, content):
for i in range(n):
self.cycle(master, content)
await self.cycle(master, content)
master.shutdown()
def flowfile(self, path):