mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 02:10:59 +00:00
asyncio simplify: we don't need a queue for proxy->main loop comms
Instead, we just schedule coroutines directly onto the core loop.
This commit is contained in:
parent
cdbe6f97af
commit
0fa1280daa
1
.gitignore
vendored
1
.gitignore
vendored
@ -14,6 +14,7 @@ build/
|
|||||||
dist/
|
dist/
|
||||||
mitmproxy/contrib/kaitaistruct/*.ksy
|
mitmproxy/contrib/kaitaistruct/*.ksy
|
||||||
.pytest_cache
|
.pytest_cache
|
||||||
|
__pycache__
|
||||||
|
|
||||||
# UI
|
# UI
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import typing
|
|||||||
import traceback
|
import traceback
|
||||||
import contextlib
|
import contextlib
|
||||||
import sys
|
import sys
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
from mitmproxy import eventsequence
|
from mitmproxy import eventsequence
|
||||||
@ -220,7 +221,7 @@ class AddonManager:
|
|||||||
name = _get_name(item)
|
name = _get_name(item)
|
||||||
return name in self.lookup
|
return name in self.lookup
|
||||||
|
|
||||||
def handle_lifecycle(self, name, message):
|
async def handle_lifecycle(self, name, message):
|
||||||
"""
|
"""
|
||||||
Handle a lifecycle event.
|
Handle a lifecycle event.
|
||||||
"""
|
"""
|
||||||
|
@ -8,10 +8,10 @@ class Channel:
|
|||||||
The only way for the proxy server to communicate with the master
|
The only way for the proxy server to communicate with the master
|
||||||
is to use the channel it has been given.
|
is to use the channel it has been given.
|
||||||
"""
|
"""
|
||||||
def __init__(self, loop, q, should_exit):
|
def __init__(self, master, loop, should_exit):
|
||||||
|
self.master = master
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.should_exit = should_exit
|
self.should_exit = should_exit
|
||||||
self._q = q
|
|
||||||
|
|
||||||
def ask(self, mtype, m):
|
def ask(self, mtype, m):
|
||||||
"""
|
"""
|
||||||
@ -22,7 +22,10 @@ class Channel:
|
|||||||
exceptions.Kill: All connections should be closed immediately.
|
exceptions.Kill: All connections should be closed immediately.
|
||||||
"""
|
"""
|
||||||
m.reply = Reply(m)
|
m.reply = Reply(m)
|
||||||
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
self.master.addons.handle_lifecycle(mtype, m),
|
||||||
|
self.loop,
|
||||||
|
)
|
||||||
g = m.reply.q.get()
|
g = m.reply.q.get()
|
||||||
if g == exceptions.Kill:
|
if g == exceptions.Kill:
|
||||||
raise exceptions.Kill()
|
raise exceptions.Kill()
|
||||||
@ -34,7 +37,10 @@ class Channel:
|
|||||||
then return immediately.
|
then return immediately.
|
||||||
"""
|
"""
|
||||||
m.reply = DummyReply()
|
m.reply = DummyReply()
|
||||||
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
self.master.addons.handle_lifecycle(mtype, m),
|
||||||
|
self.loop,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
|
NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
|
||||||
|
@ -43,14 +43,12 @@ class Master:
|
|||||||
The master handles mitmproxy's main event loop.
|
The master handles mitmproxy's main event loop.
|
||||||
"""
|
"""
|
||||||
def __init__(self, opts):
|
def __init__(self, opts):
|
||||||
self.event_queue = asyncio.Queue()
|
|
||||||
self.should_exit = threading.Event()
|
self.should_exit = threading.Event()
|
||||||
self.channel = controller.Channel(
|
self.channel = controller.Channel(
|
||||||
|
self,
|
||||||
asyncio.get_event_loop(),
|
asyncio.get_event_loop(),
|
||||||
self.event_queue,
|
|
||||||
self.should_exit,
|
self.should_exit,
|
||||||
)
|
)
|
||||||
asyncio.ensure_future(self.main())
|
|
||||||
asyncio.ensure_future(self.tick())
|
asyncio.ensure_future(self.tick())
|
||||||
|
|
||||||
self.options = opts or options.Options() # type: options.Options
|
self.options = opts or options.Options() # type: options.Options
|
||||||
@ -96,17 +94,6 @@ class Master:
|
|||||||
if self.server:
|
if self.server:
|
||||||
ServerThread(self.server).start()
|
ServerThread(self.server).start()
|
||||||
|
|
||||||
async def main(self):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
mtype, obj = await self.event_queue.get()
|
|
||||||
except RuntimeError:
|
|
||||||
return
|
|
||||||
if mtype not in eventsequence.Events: # pragma: no cover
|
|
||||||
raise exceptions.ControlException("Unknown event %s" % repr(mtype))
|
|
||||||
self.addons.handle_lifecycle(mtype, obj)
|
|
||||||
self.event_queue.task_done()
|
|
||||||
|
|
||||||
async def tick(self):
|
async def tick(self):
|
||||||
if self.first_tick:
|
if self.first_tick:
|
||||||
self.first_tick = False
|
self.first_tick = False
|
||||||
@ -145,7 +132,7 @@ class Master:
|
|||||||
f.request.host, f.request.port = upstream_spec.address
|
f.request.host, f.request.port = upstream_spec.address
|
||||||
f.request.scheme = upstream_spec.scheme
|
f.request.scheme = upstream_spec.scheme
|
||||||
|
|
||||||
def load_flow(self, f):
|
async def load_flow(self, f):
|
||||||
"""
|
"""
|
||||||
Loads a flow and links websocket & handshake flows
|
Loads a flow and links websocket & handshake flows
|
||||||
"""
|
"""
|
||||||
@ -163,7 +150,7 @@ class Master:
|
|||||||
|
|
||||||
f.reply = controller.DummyReply()
|
f.reply = controller.DummyReply()
|
||||||
for e, o in eventsequence.iterate(f):
|
for e, o in eventsequence.iterate(f):
|
||||||
self.addons.handle_lifecycle(e, o)
|
await self.addons.handle_lifecycle(e, o)
|
||||||
|
|
||||||
def replay_request(
|
def replay_request(
|
||||||
self,
|
self,
|
||||||
|
@ -4,6 +4,7 @@ import logging
|
|||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import mitmproxy.flow
|
import mitmproxy.flow
|
||||||
import tornado.escape
|
import tornado.escape
|
||||||
@ -235,7 +236,7 @@ class DumpFlows(RequestHandler):
|
|||||||
self.view.clear()
|
self.view.clear()
|
||||||
bio = BytesIO(self.filecontents)
|
bio = BytesIO(self.filecontents)
|
||||||
for i in io.FlowReader(bio).stream():
|
for i in io.FlowReader(bio).stream():
|
||||||
self.master.load_flow(i)
|
asyncio.call_soon(self.master.load_flow, i)
|
||||||
bio.close()
|
bio.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -123,8 +123,6 @@ class TcpMixin:
|
|||||||
i2 = self.pathod("306")
|
i2 = self.pathod("306")
|
||||||
self._ignore_off()
|
self._ignore_off()
|
||||||
|
|
||||||
self.master.event_queue.join()
|
|
||||||
|
|
||||||
assert n.status_code == 304
|
assert n.status_code == 304
|
||||||
assert i.status_code == 305
|
assert i.status_code == 305
|
||||||
assert i2.status_code == 306
|
assert i2.status_code == 306
|
||||||
@ -168,8 +166,6 @@ class TcpMixin:
|
|||||||
i2 = self.pathod("306")
|
i2 = self.pathod("306")
|
||||||
self._tcpproxy_off()
|
self._tcpproxy_off()
|
||||||
|
|
||||||
self.master.event_queue.join()
|
|
||||||
|
|
||||||
assert n.status_code == 304
|
assert n.status_code == 304
|
||||||
assert i.status_code == 305
|
assert i.status_code == 305
|
||||||
assert i2.status_code == 306
|
assert i2.status_code == 306
|
||||||
|
@ -65,7 +65,8 @@ def test_halt():
|
|||||||
assert end.custom_called
|
assert end.custom_called
|
||||||
|
|
||||||
|
|
||||||
def test_lifecycle():
|
@pytest.mark.asyncio
|
||||||
|
async def test_lifecycle():
|
||||||
o = options.Options()
|
o = options.Options()
|
||||||
m = master.Master(o)
|
m = master.Master(o)
|
||||||
a = addonmanager.AddonManager(m)
|
a = addonmanager.AddonManager(m)
|
||||||
@ -77,7 +78,7 @@ def test_lifecycle():
|
|||||||
a.remove(TAddon("nonexistent"))
|
a.remove(TAddon("nonexistent"))
|
||||||
|
|
||||||
f = tflow.tflow()
|
f = tflow.tflow()
|
||||||
a.handle_lifecycle("request", f)
|
await a.handle_lifecycle("request", f)
|
||||||
|
|
||||||
a._configure_all(o, o.keys())
|
a._configure_all(o, o.keys())
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import io
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mitmproxy.test import tflow, tutils
|
from mitmproxy.test import tflow, tutils, taddons
|
||||||
import mitmproxy.io
|
import mitmproxy.io
|
||||||
from mitmproxy import flowfilter
|
from mitmproxy import flowfilter
|
||||||
from mitmproxy import options
|
from mitmproxy import options
|
||||||
@ -97,27 +97,27 @@ class TestSerialize:
|
|||||||
|
|
||||||
|
|
||||||
class TestFlowMaster:
|
class TestFlowMaster:
|
||||||
def test_load_http_flow_reverse(self):
|
@pytest.mark.asyncio
|
||||||
s = tservers.TestState()
|
async def test_load_http_flow_reverse(self):
|
||||||
opts = options.Options(
|
opts = options.Options(
|
||||||
mode="reverse:https://use-this-domain"
|
mode="reverse:https://use-this-domain"
|
||||||
)
|
)
|
||||||
fm = master.Master(opts)
|
s = tservers.TestState()
|
||||||
fm.addons.add(s)
|
with taddons.context(s, options=opts) as ctx:
|
||||||
f = tflow.tflow(resp=True)
|
f = tflow.tflow(resp=True)
|
||||||
fm.load_flow(f)
|
await ctx.master.load_flow(f)
|
||||||
assert s.flows[0].request.host == "use-this-domain"
|
assert s.flows[0].request.host == "use-this-domain"
|
||||||
|
|
||||||
def test_load_websocket_flow(self):
|
@pytest.mark.asyncio
|
||||||
s = tservers.TestState()
|
async def test_load_websocket_flow(self):
|
||||||
opts = options.Options(
|
opts = options.Options(
|
||||||
mode="reverse:https://use-this-domain"
|
mode="reverse:https://use-this-domain"
|
||||||
)
|
)
|
||||||
fm = master.Master(opts)
|
s = tservers.TestState()
|
||||||
fm.addons.add(s)
|
with taddons.context(s, options=opts) as ctx:
|
||||||
f = tflow.twebsocketflow()
|
f = tflow.twebsocketflow()
|
||||||
fm.load_flow(f.handshake_flow)
|
await ctx.master.load_flow(f.handshake_flow)
|
||||||
fm.load_flow(f)
|
await ctx.master.load_flow(f)
|
||||||
assert s.flows[0].request.host == "use-this-domain"
|
assert s.flows[0].request.host == "use-this-domain"
|
||||||
assert s.flows[1].handshake_flow == f.handshake_flow
|
assert s.flows[1].handshake_flow == f.handshake_flow
|
||||||
assert len(s.flows[1].messages) == len(f.messages)
|
assert len(s.flows[1].messages) == len(f.messages)
|
||||||
@ -150,31 +150,27 @@ class TestFlowMaster:
|
|||||||
assert rt.f.request.http_version == "HTTP/1.1"
|
assert rt.f.request.http_version == "HTTP/1.1"
|
||||||
assert ":authority" not in rt.f.request.headers
|
assert ":authority" not in rt.f.request.headers
|
||||||
|
|
||||||
def test_all(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_all(self):
|
||||||
|
opts = options.Options(
|
||||||
|
mode="reverse:https://use-this-domain"
|
||||||
|
)
|
||||||
s = tservers.TestState()
|
s = tservers.TestState()
|
||||||
fm = master.Master(None)
|
with taddons.context(s, options=opts) as ctx:
|
||||||
fm.addons.add(s)
|
|
||||||
f = tflow.tflow(req=None)
|
f = tflow.tflow(req=None)
|
||||||
fm.addons.handle_lifecycle("clientconnect", f.client_conn)
|
await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
|
||||||
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
|
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
|
||||||
fm.addons.handle_lifecycle("request", f)
|
await ctx.master.addons.handle_lifecycle("request", f)
|
||||||
assert len(s.flows) == 1
|
assert len(s.flows) == 1
|
||||||
|
|
||||||
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
|
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
|
||||||
fm.addons.handle_lifecycle("response", f)
|
await ctx.master.addons.handle_lifecycle("response", f)
|
||||||
assert len(s.flows) == 1
|
assert len(s.flows) == 1
|
||||||
|
|
||||||
fm.addons.handle_lifecycle("clientdisconnect", f.client_conn)
|
await ctx.master.addons.handle_lifecycle("clientdisconnect", f.client_conn)
|
||||||
|
|
||||||
f.error = flow.Error("msg")
|
f.error = flow.Error("msg")
|
||||||
fm.addons.handle_lifecycle("error", f)
|
await ctx.master.addons.handle_lifecycle("error", f)
|
||||||
|
|
||||||
# FIXME: This no longer works, because we consume on the main loop.
|
|
||||||
# fm.tell("foo", f)
|
|
||||||
# with pytest.raises(ControlException):
|
|
||||||
# fm.addons.trigger("unknown")
|
|
||||||
|
|
||||||
fm.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestError:
|
class TestError:
|
||||||
|
@ -4,13 +4,15 @@ from mitmproxy.tools.console import keymap
|
|||||||
from mitmproxy.tools.console import master
|
from mitmproxy.tools.console import master
|
||||||
from mitmproxy import command
|
from mitmproxy import command
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
def test_commands_exist():
|
@pytest.mark.asyncio
|
||||||
|
async def test_commands_exist():
|
||||||
km = keymap.Keymap(None)
|
km = keymap.Keymap(None)
|
||||||
defaultkeys.map(km)
|
defaultkeys.map(km)
|
||||||
assert km.bindings
|
assert km.bindings
|
||||||
m = master.ConsoleMaster(None)
|
m = master.ConsoleMaster(None)
|
||||||
m.load_flow(tflow())
|
await m.load_flow(tflow())
|
||||||
|
|
||||||
for binding in km.bindings:
|
for binding in km.bindings:
|
||||||
cmd, *args = command.lexer(binding.command)
|
cmd, *args = command.lexer(binding.command)
|
||||||
|
@ -4,6 +4,10 @@ from mitmproxy import options
|
|||||||
from mitmproxy.tools import console
|
from mitmproxy.tools import console
|
||||||
from ... import tservers
|
from ... import tservers
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
class TestMaster(tservers.MasterTest):
|
class TestMaster(tservers.MasterTest):
|
||||||
def mkmaster(self, **opts):
|
def mkmaster(self, **opts):
|
||||||
@ -12,11 +16,11 @@ class TestMaster(tservers.MasterTest):
|
|||||||
m.addons.trigger("configure", o.keys())
|
m.addons.trigger("configure", o.keys())
|
||||||
return m
|
return m
|
||||||
|
|
||||||
def test_basic(self):
|
async def test_basic(self):
|
||||||
m = self.mkmaster()
|
m = self.mkmaster()
|
||||||
for i in (1, 2, 3):
|
for i in (1, 2, 3):
|
||||||
try:
|
try:
|
||||||
self.dummy_cycle(m, 1, b"")
|
await self.dummy_cycle(m, 1, b"")
|
||||||
except urwid.ExitMainLoop:
|
except urwid.ExitMainLoop:
|
||||||
pass
|
pass
|
||||||
assert len(m.view) == i
|
assert len(m.view) == i
|
||||||
|
@ -2,6 +2,7 @@ import json as _json
|
|||||||
import logging
|
import logging
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import tornado.testing
|
import tornado.testing
|
||||||
@ -32,6 +33,11 @@ def json(resp: httpclient.HTTPResponse):
|
|||||||
|
|
||||||
@pytest.mark.usefixtures("no_tornado_logging")
|
@pytest.mark.usefixtures("no_tornado_logging")
|
||||||
class TestApp(tornado.testing.AsyncHTTPTestCase):
|
class TestApp(tornado.testing.AsyncHTTPTestCase):
|
||||||
|
def get_new_ioloop(self):
|
||||||
|
io_loop = tornado.platform.asyncio.AsyncIOLoop()
|
||||||
|
asyncio.set_event_loop(io_loop.asyncio_loop)
|
||||||
|
return io_loop
|
||||||
|
|
||||||
def get_app(self):
|
def get_app(self):
|
||||||
o = options.Options(http2=False)
|
o = options.Options(http2=False)
|
||||||
m = webmaster.WebMaster(o, with_termlog=False)
|
m = webmaster.WebMaster(o, with_termlog=False)
|
||||||
@ -75,12 +81,6 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
|
|||||||
resp = self.fetch("/flows/dump")
|
resp = self.fetch("/flows/dump")
|
||||||
assert b"address" in resp.body
|
assert b"address" in resp.body
|
||||||
|
|
||||||
self.view.clear()
|
|
||||||
assert not len(self.view)
|
|
||||||
|
|
||||||
assert self.fetch("/flows/dump", method="POST", body=resp.body).code == 200
|
|
||||||
assert len(self.view)
|
|
||||||
|
|
||||||
def test_clear(self):
|
def test_clear(self):
|
||||||
events = self.events.data.copy()
|
events = self.events.data.copy()
|
||||||
flows = list(self.view)
|
flows = list(self.view)
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from mitmproxy.tools.web import master
|
from mitmproxy.tools.web import master
|
||||||
from mitmproxy import options
|
from mitmproxy import options
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ... import tservers
|
from ... import tservers
|
||||||
|
|
||||||
|
|
||||||
@ -9,8 +11,9 @@ class TestWebMaster(tservers.MasterTest):
|
|||||||
o = options.Options(**opts)
|
o = options.Options(**opts)
|
||||||
return master.WebMaster(o)
|
return master.WebMaster(o)
|
||||||
|
|
||||||
def test_basic(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic(self):
|
||||||
m = self.mkmaster()
|
m = self.mkmaster()
|
||||||
for i in (1, 2, 3):
|
for i in (1, 2, 3):
|
||||||
self.dummy_cycle(m, 1, b"")
|
await self.dummy_cycle(m, 1, b"")
|
||||||
assert len(m.view) == i
|
assert len(m.view) == i
|
||||||
|
@ -26,20 +26,20 @@ from mitmproxy.test import taddons
|
|||||||
|
|
||||||
class MasterTest:
|
class MasterTest:
|
||||||
|
|
||||||
def cycle(self, master, content):
|
async def cycle(self, master, content):
|
||||||
f = tflow.tflow(req=tutils.treq(content=content))
|
f = tflow.tflow(req=tutils.treq(content=content))
|
||||||
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
|
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
|
||||||
layer.client_conn = f.client_conn
|
layer.client_conn = f.client_conn
|
||||||
layer.reply = controller.DummyReply()
|
layer.reply = controller.DummyReply()
|
||||||
master.addons.handle_lifecycle("clientconnect", layer)
|
await master.addons.handle_lifecycle("clientconnect", layer)
|
||||||
for i in eventsequence.iterate(f):
|
for i in eventsequence.iterate(f):
|
||||||
master.addons.handle_lifecycle(*i)
|
await master.addons.handle_lifecycle(*i)
|
||||||
master.addons.handle_lifecycle("clientdisconnect", layer)
|
await master.addons.handle_lifecycle("clientdisconnect", layer)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
def dummy_cycle(self, master, n, content):
|
async def dummy_cycle(self, master, n, content):
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
self.cycle(master, content)
|
await self.cycle(master, content)
|
||||||
master.shutdown()
|
master.shutdown()
|
||||||
|
|
||||||
def flowfile(self, path):
|
def flowfile(self, path):
|
||||||
|
Loading…
Reference in New Issue
Block a user