asyncio simplify: we don't need a queue for proxy->main loop comms

Instead, we just schedule coroutines directly onto the core loop.
This commit is contained in:
Aldo Cortesi 2018-04-02 20:50:10 +12:00 committed by Aldo Cortesi
parent cdbe6f97af
commit 0fa1280daa
13 changed files with 83 additions and 85 deletions

1
.gitignore vendored
View File

@ -14,6 +14,7 @@ build/
dist/ dist/
mitmproxy/contrib/kaitaistruct/*.ksy mitmproxy/contrib/kaitaistruct/*.ksy
.pytest_cache .pytest_cache
__pycache__
# UI # UI

View File

@ -3,6 +3,7 @@ import typing
import traceback import traceback
import contextlib import contextlib
import sys import sys
import asyncio
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import eventsequence from mitmproxy import eventsequence
@ -220,7 +221,7 @@ class AddonManager:
name = _get_name(item) name = _get_name(item)
return name in self.lookup return name in self.lookup
def handle_lifecycle(self, name, message): async def handle_lifecycle(self, name, message):
""" """
Handle a lifecycle event. Handle a lifecycle event.
""" """

View File

@ -8,10 +8,10 @@ class Channel:
The only way for the proxy server to communicate with the master The only way for the proxy server to communicate with the master
is to use the channel it has been given. is to use the channel it has been given.
""" """
def __init__(self, loop, q, should_exit): def __init__(self, master, loop, should_exit):
self.master = master
self.loop = loop self.loop = loop
self.should_exit = should_exit self.should_exit = should_exit
self._q = q
def ask(self, mtype, m): def ask(self, mtype, m):
""" """
@ -22,7 +22,10 @@ class Channel:
exceptions.Kill: All connections should be closed immediately. exceptions.Kill: All connections should be closed immediately.
""" """
m.reply = Reply(m) m.reply = Reply(m)
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop) asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
g = m.reply.q.get() g = m.reply.q.get()
if g == exceptions.Kill: if g == exceptions.Kill:
raise exceptions.Kill() raise exceptions.Kill()
@ -34,7 +37,10 @@ class Channel:
then return immediately. then return immediately.
""" """
m.reply = DummyReply() m.reply = DummyReply()
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop) asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
NO_REPLY = object() # special object we can distinguish from a valid "None" reply. NO_REPLY = object() # special object we can distinguish from a valid "None" reply.

View File

@ -43,14 +43,12 @@ class Master:
The master handles mitmproxy's main event loop. The master handles mitmproxy's main event loop.
""" """
def __init__(self, opts): def __init__(self, opts):
self.event_queue = asyncio.Queue()
self.should_exit = threading.Event() self.should_exit = threading.Event()
self.channel = controller.Channel( self.channel = controller.Channel(
self,
asyncio.get_event_loop(), asyncio.get_event_loop(),
self.event_queue,
self.should_exit, self.should_exit,
) )
asyncio.ensure_future(self.main())
asyncio.ensure_future(self.tick()) asyncio.ensure_future(self.tick())
self.options = opts or options.Options() # type: options.Options self.options = opts or options.Options() # type: options.Options
@ -96,17 +94,6 @@ class Master:
if self.server: if self.server:
ServerThread(self.server).start() ServerThread(self.server).start()
async def main(self):
while True:
try:
mtype, obj = await self.event_queue.get()
except RuntimeError:
return
if mtype not in eventsequence.Events: # pragma: no cover
raise exceptions.ControlException("Unknown event %s" % repr(mtype))
self.addons.handle_lifecycle(mtype, obj)
self.event_queue.task_done()
async def tick(self): async def tick(self):
if self.first_tick: if self.first_tick:
self.first_tick = False self.first_tick = False
@ -145,7 +132,7 @@ class Master:
f.request.host, f.request.port = upstream_spec.address f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme f.request.scheme = upstream_spec.scheme
def load_flow(self, f): async def load_flow(self, f):
""" """
Loads a flow and links websocket & handshake flows Loads a flow and links websocket & handshake flows
""" """
@ -163,7 +150,7 @@ class Master:
f.reply = controller.DummyReply() f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f): for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o) await self.addons.handle_lifecycle(e, o)
def replay_request( def replay_request(
self, self,

View File

@ -4,6 +4,7 @@ import logging
import os.path import os.path
import re import re
from io import BytesIO from io import BytesIO
import asyncio
import mitmproxy.flow import mitmproxy.flow
import tornado.escape import tornado.escape
@ -235,7 +236,7 @@ class DumpFlows(RequestHandler):
self.view.clear() self.view.clear()
bio = BytesIO(self.filecontents) bio = BytesIO(self.filecontents)
for i in io.FlowReader(bio).stream(): for i in io.FlowReader(bio).stream():
self.master.load_flow(i) asyncio.call_soon(self.master.load_flow, i)
bio.close() bio.close()

View File

@ -123,8 +123,6 @@ class TcpMixin:
i2 = self.pathod("306") i2 = self.pathod("306")
self._ignore_off() self._ignore_off()
self.master.event_queue.join()
assert n.status_code == 304 assert n.status_code == 304
assert i.status_code == 305 assert i.status_code == 305
assert i2.status_code == 306 assert i2.status_code == 306
@ -168,8 +166,6 @@ class TcpMixin:
i2 = self.pathod("306") i2 = self.pathod("306")
self._tcpproxy_off() self._tcpproxy_off()
self.master.event_queue.join()
assert n.status_code == 304 assert n.status_code == 304
assert i.status_code == 305 assert i.status_code == 305
assert i2.status_code == 306 assert i2.status_code == 306

View File

@ -65,7 +65,8 @@ def test_halt():
assert end.custom_called assert end.custom_called
def test_lifecycle(): @pytest.mark.asyncio
async def test_lifecycle():
o = options.Options() o = options.Options()
m = master.Master(o) m = master.Master(o)
a = addonmanager.AddonManager(m) a = addonmanager.AddonManager(m)
@ -77,7 +78,7 @@ def test_lifecycle():
a.remove(TAddon("nonexistent")) a.remove(TAddon("nonexistent"))
f = tflow.tflow() f = tflow.tflow()
a.handle_lifecycle("request", f) await a.handle_lifecycle("request", f)
a._configure_all(o, o.keys()) a._configure_all(o, o.keys())

View File

@ -2,7 +2,7 @@ import io
from unittest import mock from unittest import mock
import pytest import pytest
from mitmproxy.test import tflow, tutils from mitmproxy.test import tflow, tutils, taddons
import mitmproxy.io import mitmproxy.io
from mitmproxy import flowfilter from mitmproxy import flowfilter
from mitmproxy import options from mitmproxy import options
@ -97,27 +97,27 @@ class TestSerialize:
class TestFlowMaster: class TestFlowMaster:
def test_load_http_flow_reverse(self): @pytest.mark.asyncio
s = tservers.TestState() async def test_load_http_flow_reverse(self):
opts = options.Options( opts = options.Options(
mode="reverse:https://use-this-domain" mode="reverse:https://use-this-domain"
) )
fm = master.Master(opts) s = tservers.TestState()
fm.addons.add(s) with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
fm.load_flow(f) await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain" assert s.flows[0].request.host == "use-this-domain"
def test_load_websocket_flow(self): @pytest.mark.asyncio
s = tservers.TestState() async def test_load_websocket_flow(self):
opts = options.Options( opts = options.Options(
mode="reverse:https://use-this-domain" mode="reverse:https://use-this-domain"
) )
fm = master.Master(opts) s = tservers.TestState()
fm.addons.add(s) with taddons.context(s, options=opts) as ctx:
f = tflow.twebsocketflow() f = tflow.twebsocketflow()
fm.load_flow(f.handshake_flow) await ctx.master.load_flow(f.handshake_flow)
fm.load_flow(f) await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain" assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages) assert len(s.flows[1].messages) == len(f.messages)
@ -150,31 +150,27 @@ class TestFlowMaster:
assert rt.f.request.http_version == "HTTP/1.1" assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers assert ":authority" not in rt.f.request.headers
def test_all(self): @pytest.mark.asyncio
async def test_all(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
s = tservers.TestState() s = tservers.TestState()
fm = master.Master(None) with taddons.context(s, options=opts) as ctx:
fm.addons.add(s)
f = tflow.tflow(req=None) f = tflow.tflow(req=None)
fm.addons.handle_lifecycle("clientconnect", f.client_conn) await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
fm.addons.handle_lifecycle("request", f) await ctx.master.addons.handle_lifecycle("request", f)
assert len(s.flows) == 1 assert len(s.flows) == 1
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
fm.addons.handle_lifecycle("response", f) await ctx.master.addons.handle_lifecycle("response", f)
assert len(s.flows) == 1 assert len(s.flows) == 1
fm.addons.handle_lifecycle("clientdisconnect", f.client_conn) await ctx.master.addons.handle_lifecycle("clientdisconnect", f.client_conn)
f.error = flow.Error("msg") f.error = flow.Error("msg")
fm.addons.handle_lifecycle("error", f) await ctx.master.addons.handle_lifecycle("error", f)
# FIXME: This no longer works, because we consume on the main loop.
# fm.tell("foo", f)
# with pytest.raises(ControlException):
# fm.addons.trigger("unknown")
fm.shutdown()
class TestError: class TestError:

View File

@ -4,13 +4,15 @@ from mitmproxy.tools.console import keymap
from mitmproxy.tools.console import master from mitmproxy.tools.console import master
from mitmproxy import command from mitmproxy import command
import pytest
def test_commands_exist(): @pytest.mark.asyncio
async def test_commands_exist():
km = keymap.Keymap(None) km = keymap.Keymap(None)
defaultkeys.map(km) defaultkeys.map(km)
assert km.bindings assert km.bindings
m = master.ConsoleMaster(None) m = master.ConsoleMaster(None)
m.load_flow(tflow()) await m.load_flow(tflow())
for binding in km.bindings: for binding in km.bindings:
cmd, *args = command.lexer(binding.command) cmd, *args = command.lexer(binding.command)

View File

@ -4,6 +4,10 @@ from mitmproxy import options
from mitmproxy.tools import console from mitmproxy.tools import console
from ... import tservers from ... import tservers
import pytest
@pytest.mark.asyncio
class TestMaster(tservers.MasterTest): class TestMaster(tservers.MasterTest):
def mkmaster(self, **opts): def mkmaster(self, **opts):
@ -12,11 +16,11 @@ class TestMaster(tservers.MasterTest):
m.addons.trigger("configure", o.keys()) m.addons.trigger("configure", o.keys())
return m return m
def test_basic(self): async def test_basic(self):
m = self.mkmaster() m = self.mkmaster()
for i in (1, 2, 3): for i in (1, 2, 3):
try: try:
self.dummy_cycle(m, 1, b"") await self.dummy_cycle(m, 1, b"")
except urwid.ExitMainLoop: except urwid.ExitMainLoop:
pass pass
assert len(m.view) == i assert len(m.view) == i

View File

@ -2,6 +2,7 @@ import json as _json
import logging import logging
from unittest import mock from unittest import mock
import os import os
import asyncio
import pytest import pytest
import tornado.testing import tornado.testing
@ -32,6 +33,11 @@ def json(resp: httpclient.HTTPResponse):
@pytest.mark.usefixtures("no_tornado_logging") @pytest.mark.usefixtures("no_tornado_logging")
class TestApp(tornado.testing.AsyncHTTPTestCase): class TestApp(tornado.testing.AsyncHTTPTestCase):
def get_new_ioloop(self):
io_loop = tornado.platform.asyncio.AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop
def get_app(self): def get_app(self):
o = options.Options(http2=False) o = options.Options(http2=False)
m = webmaster.WebMaster(o, with_termlog=False) m = webmaster.WebMaster(o, with_termlog=False)
@ -75,12 +81,6 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
resp = self.fetch("/flows/dump") resp = self.fetch("/flows/dump")
assert b"address" in resp.body assert b"address" in resp.body
self.view.clear()
assert not len(self.view)
assert self.fetch("/flows/dump", method="POST", body=resp.body).code == 200
assert len(self.view)
def test_clear(self): def test_clear(self):
events = self.events.data.copy() events = self.events.data.copy()
flows = list(self.view) flows = list(self.view)

View File

@ -1,6 +1,8 @@
from mitmproxy.tools.web import master from mitmproxy.tools.web import master
from mitmproxy import options from mitmproxy import options
import pytest
from ... import tservers from ... import tservers
@ -9,8 +11,9 @@ class TestWebMaster(tservers.MasterTest):
o = options.Options(**opts) o = options.Options(**opts)
return master.WebMaster(o) return master.WebMaster(o)
def test_basic(self): @pytest.mark.asyncio
async def test_basic(self):
m = self.mkmaster() m = self.mkmaster()
for i in (1, 2, 3): for i in (1, 2, 3):
self.dummy_cycle(m, 1, b"") await self.dummy_cycle(m, 1, b"")
assert len(m.view) == i assert len(m.view) == i

View File

@ -26,20 +26,20 @@ from mitmproxy.test import taddons
class MasterTest: class MasterTest:
def cycle(self, master, content): async def cycle(self, master, content):
f = tflow.tflow(req=tutils.treq(content=content)) f = tflow.tflow(req=tutils.treq(content=content))
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer") layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
layer.client_conn = f.client_conn layer.client_conn = f.client_conn
layer.reply = controller.DummyReply() layer.reply = controller.DummyReply()
master.addons.handle_lifecycle("clientconnect", layer) await master.addons.handle_lifecycle("clientconnect", layer)
for i in eventsequence.iterate(f): for i in eventsequence.iterate(f):
master.addons.handle_lifecycle(*i) await master.addons.handle_lifecycle(*i)
master.addons.handle_lifecycle("clientdisconnect", layer) await master.addons.handle_lifecycle("clientdisconnect", layer)
return f return f
def dummy_cycle(self, master, n, content): async def dummy_cycle(self, master, n, content):
for i in range(n): for i in range(n):
self.cycle(master, content) await self.cycle(master, content)
master.shutdown() master.shutdown()
def flowfile(self, path): def flowfile(self, path):