mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 07:08:10 +00:00
asyncio simplify: we don't need a queue for proxy->main loop comms
Instead, we just schedule coroutines directly onto the core loop.
This commit is contained in:
parent
cdbe6f97af
commit
0fa1280daa
1
.gitignore
vendored
1
.gitignore
vendored
@ -14,6 +14,7 @@ build/
|
||||
dist/
|
||||
mitmproxy/contrib/kaitaistruct/*.ksy
|
||||
.pytest_cache
|
||||
__pycache__
|
||||
|
||||
# UI
|
||||
|
||||
|
@ -3,6 +3,7 @@ import typing
|
||||
import traceback
|
||||
import contextlib
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import eventsequence
|
||||
@ -220,7 +221,7 @@ class AddonManager:
|
||||
name = _get_name(item)
|
||||
return name in self.lookup
|
||||
|
||||
def handle_lifecycle(self, name, message):
|
||||
async def handle_lifecycle(self, name, message):
|
||||
"""
|
||||
Handle a lifecycle event.
|
||||
"""
|
||||
|
@ -8,10 +8,10 @@ class Channel:
|
||||
The only way for the proxy server to communicate with the master
|
||||
is to use the channel it has been given.
|
||||
"""
|
||||
def __init__(self, loop, q, should_exit):
|
||||
def __init__(self, master, loop, should_exit):
|
||||
self.master = master
|
||||
self.loop = loop
|
||||
self.should_exit = should_exit
|
||||
self._q = q
|
||||
|
||||
def ask(self, mtype, m):
|
||||
"""
|
||||
@ -22,7 +22,10 @@ class Channel:
|
||||
exceptions.Kill: All connections should be closed immediately.
|
||||
"""
|
||||
m.reply = Reply(m)
|
||||
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.master.addons.handle_lifecycle(mtype, m),
|
||||
self.loop,
|
||||
)
|
||||
g = m.reply.q.get()
|
||||
if g == exceptions.Kill:
|
||||
raise exceptions.Kill()
|
||||
@ -34,7 +37,10 @@ class Channel:
|
||||
then return immediately.
|
||||
"""
|
||||
m.reply = DummyReply()
|
||||
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.master.addons.handle_lifecycle(mtype, m),
|
||||
self.loop,
|
||||
)
|
||||
|
||||
|
||||
NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
|
||||
|
@ -43,14 +43,12 @@ class Master:
|
||||
The master handles mitmproxy's main event loop.
|
||||
"""
|
||||
def __init__(self, opts):
|
||||
self.event_queue = asyncio.Queue()
|
||||
self.should_exit = threading.Event()
|
||||
self.channel = controller.Channel(
|
||||
self,
|
||||
asyncio.get_event_loop(),
|
||||
self.event_queue,
|
||||
self.should_exit,
|
||||
)
|
||||
asyncio.ensure_future(self.main())
|
||||
asyncio.ensure_future(self.tick())
|
||||
|
||||
self.options = opts or options.Options() # type: options.Options
|
||||
@ -96,17 +94,6 @@ class Master:
|
||||
if self.server:
|
||||
ServerThread(self.server).start()
|
||||
|
||||
async def main(self):
|
||||
while True:
|
||||
try:
|
||||
mtype, obj = await self.event_queue.get()
|
||||
except RuntimeError:
|
||||
return
|
||||
if mtype not in eventsequence.Events: # pragma: no cover
|
||||
raise exceptions.ControlException("Unknown event %s" % repr(mtype))
|
||||
self.addons.handle_lifecycle(mtype, obj)
|
||||
self.event_queue.task_done()
|
||||
|
||||
async def tick(self):
|
||||
if self.first_tick:
|
||||
self.first_tick = False
|
||||
@ -145,7 +132,7 @@ class Master:
|
||||
f.request.host, f.request.port = upstream_spec.address
|
||||
f.request.scheme = upstream_spec.scheme
|
||||
|
||||
def load_flow(self, f):
|
||||
async def load_flow(self, f):
|
||||
"""
|
||||
Loads a flow and links websocket & handshake flows
|
||||
"""
|
||||
@ -163,7 +150,7 @@ class Master:
|
||||
|
||||
f.reply = controller.DummyReply()
|
||||
for e, o in eventsequence.iterate(f):
|
||||
self.addons.handle_lifecycle(e, o)
|
||||
await self.addons.handle_lifecycle(e, o)
|
||||
|
||||
def replay_request(
|
||||
self,
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
import os.path
|
||||
import re
|
||||
from io import BytesIO
|
||||
import asyncio
|
||||
|
||||
import mitmproxy.flow
|
||||
import tornado.escape
|
||||
@ -235,7 +236,7 @@ class DumpFlows(RequestHandler):
|
||||
self.view.clear()
|
||||
bio = BytesIO(self.filecontents)
|
||||
for i in io.FlowReader(bio).stream():
|
||||
self.master.load_flow(i)
|
||||
asyncio.call_soon(self.master.load_flow, i)
|
||||
bio.close()
|
||||
|
||||
|
||||
|
@ -123,8 +123,6 @@ class TcpMixin:
|
||||
i2 = self.pathod("306")
|
||||
self._ignore_off()
|
||||
|
||||
self.master.event_queue.join()
|
||||
|
||||
assert n.status_code == 304
|
||||
assert i.status_code == 305
|
||||
assert i2.status_code == 306
|
||||
@ -168,8 +166,6 @@ class TcpMixin:
|
||||
i2 = self.pathod("306")
|
||||
self._tcpproxy_off()
|
||||
|
||||
self.master.event_queue.join()
|
||||
|
||||
assert n.status_code == 304
|
||||
assert i.status_code == 305
|
||||
assert i2.status_code == 306
|
||||
|
@ -65,7 +65,8 @@ def test_halt():
|
||||
assert end.custom_called
|
||||
|
||||
|
||||
def test_lifecycle():
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle():
|
||||
o = options.Options()
|
||||
m = master.Master(o)
|
||||
a = addonmanager.AddonManager(m)
|
||||
@ -77,7 +78,7 @@ def test_lifecycle():
|
||||
a.remove(TAddon("nonexistent"))
|
||||
|
||||
f = tflow.tflow()
|
||||
a.handle_lifecycle("request", f)
|
||||
await a.handle_lifecycle("request", f)
|
||||
|
||||
a._configure_all(o, o.keys())
|
||||
|
||||
|
@ -2,7 +2,7 @@ import io
|
||||
from unittest import mock
|
||||
import pytest
|
||||
|
||||
from mitmproxy.test import tflow, tutils
|
||||
from mitmproxy.test import tflow, tutils, taddons
|
||||
import mitmproxy.io
|
||||
from mitmproxy import flowfilter
|
||||
from mitmproxy import options
|
||||
@ -97,27 +97,27 @@ class TestSerialize:
|
||||
|
||||
|
||||
class TestFlowMaster:
|
||||
def test_load_http_flow_reverse(self):
|
||||
s = tservers.TestState()
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_http_flow_reverse(self):
|
||||
opts = options.Options(
|
||||
mode="reverse:https://use-this-domain"
|
||||
)
|
||||
fm = master.Master(opts)
|
||||
fm.addons.add(s)
|
||||
s = tservers.TestState()
|
||||
with taddons.context(s, options=opts) as ctx:
|
||||
f = tflow.tflow(resp=True)
|
||||
fm.load_flow(f)
|
||||
await ctx.master.load_flow(f)
|
||||
assert s.flows[0].request.host == "use-this-domain"
|
||||
|
||||
def test_load_websocket_flow(self):
|
||||
s = tservers.TestState()
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_websocket_flow(self):
|
||||
opts = options.Options(
|
||||
mode="reverse:https://use-this-domain"
|
||||
)
|
||||
fm = master.Master(opts)
|
||||
fm.addons.add(s)
|
||||
s = tservers.TestState()
|
||||
with taddons.context(s, options=opts) as ctx:
|
||||
f = tflow.twebsocketflow()
|
||||
fm.load_flow(f.handshake_flow)
|
||||
fm.load_flow(f)
|
||||
await ctx.master.load_flow(f.handshake_flow)
|
||||
await ctx.master.load_flow(f)
|
||||
assert s.flows[0].request.host == "use-this-domain"
|
||||
assert s.flows[1].handshake_flow == f.handshake_flow
|
||||
assert len(s.flows[1].messages) == len(f.messages)
|
||||
@ -150,31 +150,27 @@ class TestFlowMaster:
|
||||
assert rt.f.request.http_version == "HTTP/1.1"
|
||||
assert ":authority" not in rt.f.request.headers
|
||||
|
||||
def test_all(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_all(self):
|
||||
opts = options.Options(
|
||||
mode="reverse:https://use-this-domain"
|
||||
)
|
||||
s = tservers.TestState()
|
||||
fm = master.Master(None)
|
||||
fm.addons.add(s)
|
||||
with taddons.context(s, options=opts) as ctx:
|
||||
f = tflow.tflow(req=None)
|
||||
fm.addons.handle_lifecycle("clientconnect", f.client_conn)
|
||||
await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
|
||||
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
|
||||
fm.addons.handle_lifecycle("request", f)
|
||||
await ctx.master.addons.handle_lifecycle("request", f)
|
||||
assert len(s.flows) == 1
|
||||
|
||||
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
|
||||
fm.addons.handle_lifecycle("response", f)
|
||||
await ctx.master.addons.handle_lifecycle("response", f)
|
||||
assert len(s.flows) == 1
|
||||
|
||||
fm.addons.handle_lifecycle("clientdisconnect", f.client_conn)
|
||||
await ctx.master.addons.handle_lifecycle("clientdisconnect", f.client_conn)
|
||||
|
||||
f.error = flow.Error("msg")
|
||||
fm.addons.handle_lifecycle("error", f)
|
||||
|
||||
# FIXME: This no longer works, because we consume on the main loop.
|
||||
# fm.tell("foo", f)
|
||||
# with pytest.raises(ControlException):
|
||||
# fm.addons.trigger("unknown")
|
||||
|
||||
fm.shutdown()
|
||||
await ctx.master.addons.handle_lifecycle("error", f)
|
||||
|
||||
|
||||
class TestError:
|
||||
|
@ -4,13 +4,15 @@ from mitmproxy.tools.console import keymap
|
||||
from mitmproxy.tools.console import master
|
||||
from mitmproxy import command
|
||||
|
||||
import pytest
|
||||
|
||||
def test_commands_exist():
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands_exist():
|
||||
km = keymap.Keymap(None)
|
||||
defaultkeys.map(km)
|
||||
assert km.bindings
|
||||
m = master.ConsoleMaster(None)
|
||||
m.load_flow(tflow())
|
||||
await m.load_flow(tflow())
|
||||
|
||||
for binding in km.bindings:
|
||||
cmd, *args = command.lexer(binding.command)
|
||||
|
@ -4,6 +4,10 @@ from mitmproxy import options
|
||||
from mitmproxy.tools import console
|
||||
from ... import tservers
|
||||
|
||||
import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestMaster(tservers.MasterTest):
|
||||
def mkmaster(self, **opts):
|
||||
@ -12,11 +16,11 @@ class TestMaster(tservers.MasterTest):
|
||||
m.addons.trigger("configure", o.keys())
|
||||
return m
|
||||
|
||||
def test_basic(self):
|
||||
async def test_basic(self):
|
||||
m = self.mkmaster()
|
||||
for i in (1, 2, 3):
|
||||
try:
|
||||
self.dummy_cycle(m, 1, b"")
|
||||
await self.dummy_cycle(m, 1, b"")
|
||||
except urwid.ExitMainLoop:
|
||||
pass
|
||||
assert len(m.view) == i
|
||||
|
@ -2,6 +2,7 @@ import json as _json
|
||||
import logging
|
||||
from unittest import mock
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import tornado.testing
|
||||
@ -32,6 +33,11 @@ def json(resp: httpclient.HTTPResponse):
|
||||
|
||||
@pytest.mark.usefixtures("no_tornado_logging")
|
||||
class TestApp(tornado.testing.AsyncHTTPTestCase):
|
||||
def get_new_ioloop(self):
|
||||
io_loop = tornado.platform.asyncio.AsyncIOLoop()
|
||||
asyncio.set_event_loop(io_loop.asyncio_loop)
|
||||
return io_loop
|
||||
|
||||
def get_app(self):
|
||||
o = options.Options(http2=False)
|
||||
m = webmaster.WebMaster(o, with_termlog=False)
|
||||
@ -75,12 +81,6 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
|
||||
resp = self.fetch("/flows/dump")
|
||||
assert b"address" in resp.body
|
||||
|
||||
self.view.clear()
|
||||
assert not len(self.view)
|
||||
|
||||
assert self.fetch("/flows/dump", method="POST", body=resp.body).code == 200
|
||||
assert len(self.view)
|
||||
|
||||
def test_clear(self):
|
||||
events = self.events.data.copy()
|
||||
flows = list(self.view)
|
||||
|
@ -1,6 +1,8 @@
|
||||
from mitmproxy.tools.web import master
|
||||
from mitmproxy import options
|
||||
|
||||
import pytest
|
||||
|
||||
from ... import tservers
|
||||
|
||||
|
||||
@ -9,8 +11,9 @@ class TestWebMaster(tservers.MasterTest):
|
||||
o = options.Options(**opts)
|
||||
return master.WebMaster(o)
|
||||
|
||||
def test_basic(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic(self):
|
||||
m = self.mkmaster()
|
||||
for i in (1, 2, 3):
|
||||
self.dummy_cycle(m, 1, b"")
|
||||
await self.dummy_cycle(m, 1, b"")
|
||||
assert len(m.view) == i
|
||||
|
@ -26,20 +26,20 @@ from mitmproxy.test import taddons
|
||||
|
||||
class MasterTest:
|
||||
|
||||
def cycle(self, master, content):
|
||||
async def cycle(self, master, content):
|
||||
f = tflow.tflow(req=tutils.treq(content=content))
|
||||
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
|
||||
layer.client_conn = f.client_conn
|
||||
layer.reply = controller.DummyReply()
|
||||
master.addons.handle_lifecycle("clientconnect", layer)
|
||||
await master.addons.handle_lifecycle("clientconnect", layer)
|
||||
for i in eventsequence.iterate(f):
|
||||
master.addons.handle_lifecycle(*i)
|
||||
master.addons.handle_lifecycle("clientdisconnect", layer)
|
||||
await master.addons.handle_lifecycle(*i)
|
||||
await master.addons.handle_lifecycle("clientdisconnect", layer)
|
||||
return f
|
||||
|
||||
def dummy_cycle(self, master, n, content):
|
||||
async def dummy_cycle(self, master, n, content):
|
||||
for i in range(n):
|
||||
self.cycle(master, content)
|
||||
await self.cycle(master, content)
|
||||
master.shutdown()
|
||||
|
||||
def flowfile(self, path):
|
||||
|
Loading…
Reference in New Issue
Block a user