clean up initialization mess

We now manage the eventloop ourselves no matter which tool.
This commit is contained in:
Maximilian Hils 2022-03-15 11:15:05 +01:00
parent 46ccf6049c
commit bbc65e5f37
7 changed files with 344 additions and 212 deletions

View File

View File

@ -0,0 +1,282 @@
"""
SPDX-License-Identifier: Apache-2.0
Vendored partial copy of https://github.com/tornadoweb/tornado/blob/master/tornado/platform/asyncio.py @ e18ea03
to fix https://github.com/tornadoweb/tornado/issues/3092. Can be removed once tornado >6.1 is out.
"""
import asyncio
import atexit
import errno
import functools
import socket
import threading
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import select
if typing.TYPE_CHECKING:
from typing import Set # noqa: F401
from typing_extensions import Protocol
class _HasFileno(Protocol):
def fileno(self) -> int:
pass
_FileDescriptorLike = Union[int, _HasFileno]
_T = TypeVar("_T")
# Collection of selector thread event loops to shut down on exit.
_selector_loops = set() # type: Set[AddThreadSelectorEventLoop]
def _atexit_callback() -> None:
for loop in _selector_loops:
with loop._select_cond:
loop._closing_selector = True
loop._select_cond.notify()
try:
loop._waker_w.send(b"a")
except BlockingIOError:
pass
# If we don't join our (daemon) thread here, we may get a deadlock
# during interpreter shutdown. I don't really understand why. This
# deadlock happens every time in CI (both travis and appveyor) but
# I've never been able to reproduce locally.
loop._thread.join()
_selector_loops.clear()
atexit.register(_atexit_callback)
class AddThreadSelectorEventLoop(asyncio.AbstractEventLoop):
"""Wrap an event loop to add implementations of the ``add_reader`` method family.
Instances of this class start a second thread to run a selector.
This thread is completely hidden from the user; all callbacks are
run on the wrapped event loop's thread.
This class is used automatically by Tornado; applications should not need
to refer to it directly.
It is safe to wrap any event loop with this class, although it only makes sense
for event loops that do not implement the ``add_reader`` family of methods
themselves (i.e. ``WindowsProactorEventLoop``)
Closing the ``AddThreadSelectorEventLoop`` also closes the wrapped event loop.
"""
# This class is a __getattribute__-based proxy. All attributes other than those
# in this set are proxied through to the underlying loop.
MY_ATTRIBUTES = {
"_consume_waker",
"_select_cond",
"_select_args",
"_closing_selector",
"_thread",
"_handle_event",
"_readers",
"_real_loop",
"_start_select",
"_run_select",
"_handle_select",
"_wake_selector",
"_waker_r",
"_waker_w",
"_writers",
"add_reader",
"add_writer",
"close",
"remove_reader",
"remove_writer",
}
def __getattribute__(self, name: str) -> Any:
if name in AddThreadSelectorEventLoop.MY_ATTRIBUTES:
return super().__getattribute__(name)
return getattr(self._real_loop, name)
def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None:
self._real_loop = real_loop
# Create a thread to run the select system call. We manage this thread
# manually so we can trigger a clean shutdown from an atexit hook. Note
# that due to the order of operations at shutdown, only daemon threads
# can be shut down in this way (non-daemon threads would require the
# introduction of a new hook: https://bugs.python.org/issue41962)
self._select_cond = threading.Condition()
self._select_args = (
None
) # type: Optional[Tuple[List[_FileDescriptorLike], List[_FileDescriptorLike]]]
self._closing_selector = False
self._thread = threading.Thread(
name="Tornado selector",
daemon=True,
target=self._run_select,
)
self._thread.start()
# Start the select loop once the loop is started.
self._real_loop.call_soon(self._start_select)
self._readers = {} # type: Dict[_FileDescriptorLike, Callable]
self._writers = {} # type: Dict[_FileDescriptorLike, Callable]
# Writing to _waker_w will wake up the selector thread, which
# watches for _waker_r to be readable.
self._waker_r, self._waker_w = socket.socketpair()
self._waker_r.setblocking(False)
self._waker_w.setblocking(False)
_selector_loops.add(self)
self.add_reader(self._waker_r, self._consume_waker)
def __del__(self) -> None:
# If the top-level application code uses asyncio interfaces to
# start and stop the event loop, no objects created in Tornado
# can get a clean shutdown notification. If we're just left to
# be GC'd, we must explicitly close our sockets to avoid
# logging warnings.
_selector_loops.discard(self)
self._waker_r.close()
self._waker_w.close()
def close(self) -> None:
with self._select_cond:
self._closing_selector = True
self._select_cond.notify()
self._wake_selector()
self._thread.join()
_selector_loops.discard(self)
self._waker_r.close()
self._waker_w.close()
self._real_loop.close()
def _wake_selector(self) -> None:
try:
self._waker_w.send(b"a")
except BlockingIOError:
pass
def _consume_waker(self) -> None:
try:
self._waker_r.recv(1024)
except BlockingIOError:
pass
def _start_select(self) -> None:
# Capture reader and writer sets here in the event loop
# thread to avoid any problems with concurrent
# modification while the select loop uses them.
with self._select_cond:
assert self._select_args is None
self._select_args = (list(self._readers.keys()), list(self._writers.keys()))
self._select_cond.notify()
def _run_select(self) -> None:
while True:
with self._select_cond:
while self._select_args is None and not self._closing_selector:
self._select_cond.wait()
if self._closing_selector:
return
assert self._select_args is not None
to_read, to_write = self._select_args
self._select_args = None
# We use the simpler interface of the select module instead of
# the more stateful interface in the selectors module because
# this class is only intended for use on windows, where
# select.select is the only option. The selector interface
# does not have well-documented thread-safety semantics that
# we can rely on so ensuring proper synchronization would be
# tricky.
try:
# On windows, selecting on a socket for write will not
# return the socket when there is an error (but selecting
# for reads works). Also select for errors when selecting
# for writes, and merge the results.
#
# This pattern is also used in
# https://github.com/python/cpython/blob/v3.8.0/Lib/selectors.py#L312-L317
rs, ws, xs = select.select(to_read, to_write, to_write)
ws = ws + xs
except OSError as e:
# After remove_reader or remove_writer is called, the file
# descriptor may subsequently be closed on the event loop
# thread. It's possible that this select thread hasn't
# gotten into the select system call by the time that
# happens in which case (at least on macOS), select may
# raise a "bad file descriptor" error. If we get that
# error, check and see if we're also being woken up by
# polling the waker alone. If we are, just return to the
# event loop and we'll get the updated set of file
# descriptors on the next iteration. Otherwise, raise the
# original error.
if e.errno == getattr(errno, "WSAENOTSOCK", errno.EBADF):
rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0)
if rs:
ws = []
else:
raise
else:
raise
try:
self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws)
except RuntimeError:
# "Event loop is closed". Swallow the exception for
# consistency with PollIOLoop (and logical consistency
# with the fact that we can't guarantee that an
# add_callback that completes without error will
# eventually execute).
pass
except AttributeError:
# ProactorEventLoop may raise this instead of RuntimeError
# if call_soon_threadsafe races with a call to close().
# Swallow it too for consistency.
pass
def _handle_select(
self, rs: List["_FileDescriptorLike"], ws: List["_FileDescriptorLike"]
) -> None:
for r in rs:
self._handle_event(r, self._readers)
for w in ws:
self._handle_event(w, self._writers)
self._start_select()
def _handle_event(
self,
fd: "_FileDescriptorLike",
cb_map: Dict["_FileDescriptorLike", Callable],
) -> None:
try:
callback = cb_map[fd]
except KeyError:
return
callback()
def add_reader(
self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
) -> None:
self._readers[fd] = functools.partial(callback, *args)
self._wake_selector()
def add_writer(
self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
) -> None:
self._writers[fd] = functools.partial(callback, *args)
self._wake_selector()
def remove_reader(self, fd: "_FileDescriptorLike") -> None:
del self._readers[fd]
self._wake_selector()
def remove_writer(self, fd: "_FileDescriptorLike") -> None:
del self._writers[fd]
self._wake_selector()

View File

@ -1,9 +1,5 @@
import asyncio import asyncio
import logging
import sys
import threading
import traceback import traceback
from typing import Callable
from mitmproxy import addonmanager, hooks from mitmproxy import addonmanager, hooks
from mitmproxy import command from mitmproxy import command
@ -14,39 +10,47 @@ from mitmproxy import options
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from . import ctx as mitmproxy_ctx from . import ctx as mitmproxy_ctx
# Conclusively preventing cross-thread races on proxy shutdown turns out to be
# very hard. We could build a thread sync infrastructure for this, or we could
# wait until we ditch threads and move all the protocols into the async loop.
# Until then, silence non-critical errors.
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class Master: class Master:
""" """
The master handles mitmproxy's main event loop. The master handles mitmproxy's main event loop.
""" """
event_loop: asyncio.AbstractEventLoop
def __init__(self, opts): def __init__(self, opts):
self.should_exit = threading.Event() self.should_exit = asyncio.Event()
self.event_loop = asyncio.get_event_loop()
self.options: options.Options = opts or options.Options() self.options: options.Options = opts or options.Options()
self.commands = command.CommandManager(self) self.commands = command.CommandManager(self)
self.addons = addonmanager.AddonManager(self) self.addons = addonmanager.AddonManager(self)
self._server = None
self.log = log.Log(self) self.log = log.Log(self)
mitmproxy_ctx.master = self mitmproxy_ctx.master = self
mitmproxy_ctx.log = self.log mitmproxy_ctx.log = self.log
mitmproxy_ctx.options = self.options mitmproxy_ctx.options = self.options
def start(self): async def run(self) -> None:
self.event_loop = asyncio.get_running_loop()
self.event_loop.set_exception_handler(self._asyncio_exception_handler)
self.should_exit.clear() self.should_exit.clear()
async def running(self): await self.running()
self.addons.trigger(hooks.RunningHook()) await self.should_exit.wait()
# We set the exception handler here because urwid's run() method overwrites it. await self.done()
asyncio.get_running_loop().set_exception_handler(self._asyncio_exception_handler)
def shutdown(self):
"""
Shut down the proxy. This method is thread-safe.
"""
# We may add an exception argument here.
self.event_loop.call_soon_threadsafe(self.should_exit.set)
async def running(self) -> None:
await self.addons.trigger_event(hooks.RunningHook())
async def done(self) -> None:
await self.addons.trigger_event(hooks.DoneHook())
def _asyncio_exception_handler(self, loop, context): def _asyncio_exception_handler(self, loop, context):
exc: Exception = context["exception"] exc: Exception = context["exception"]
@ -58,57 +62,6 @@ class Master:
"\n\thttps://github.com/mitmproxy/mitmproxy/issues" "\n\thttps://github.com/mitmproxy/mitmproxy/issues"
) )
def run_loop(self, run_forever: Callable) -> None:
self.start()
asyncio.ensure_future(self.running())
exc = None
try:
run_forever()
except Exception: # pragma: no cover
exc = traceback.format_exc()
finally:
if not self.should_exit.is_set(): # pragma: no cover
self.shutdown()
loop = asyncio.get_event_loop()
tasks = asyncio.all_tasks(loop)
for p in tasks:
p.cancel()
loop.close()
if exc: # pragma: no cover
print(exc, file=sys.stderr)
print("mitmproxy has crashed!", file=sys.stderr)
print("Please lodge a bug report at:", file=sys.stderr)
print("\thttps://github.com/mitmproxy/mitmproxy/issues", file=sys.stderr)
self.addons.trigger(hooks.DoneHook())
def run(self):
loop = asyncio.get_event_loop()
self.run_loop(loop.run_forever)
async def _shutdown(self):
self.should_exit.set()
loop = asyncio.get_event_loop()
loop.stop()
def shutdown(self):
"""
Shut down the proxy. This method is thread-safe.
"""
if not self.should_exit.is_set():
self.should_exit.set()
ret = asyncio.run_coroutine_threadsafe(self._shutdown(), loop=self.event_loop)
# Weird band-aid to make sure that self._shutdown() is actually executed,
# which otherwise hangs the process as the proxy server is threaded.
# This all needs to be simplified when the proxy server runs on asyncio as well.
if not self.event_loop.is_running(): # pragma: no cover
try:
self.event_loop.run_until_complete(asyncio.wrap_future(ret, loop=self.event_loop))
except RuntimeError:
pass # Event loop stopped before Future completed.
async def load_flow(self, f): async def load_flow(self, f):
""" """
Loads a flow Loads a flow

View File

@ -5,16 +5,14 @@ import os
import os.path import os.path
import shlex import shlex
import shutil import shutil
import signal
import stat import stat
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import typing # noqa
import contextlib import contextlib
import threading import threading
from tornado.platform.asyncio import AddThreadSelectorEventLoop from mitmproxy.contrib.tornado.asyncio import AddThreadSelectorEventLoop
import urwid import urwid
@ -38,8 +36,6 @@ class ConsoleMaster(master.Master):
def __init__(self, opts): def __init__(self, opts):
super().__init__(opts) super().__init__(opts)
self.start_err: typing.Optional[log.LogEntry] = None
self.view: view.View = view.View() self.view: view.View = view.View()
self.events = eventstore.EventStore() self.events = eventstore.EventStore()
self.events.sig_add.connect(self.sig_add_log) self.events.sig_add.connect(self.sig_add_log)
@ -61,11 +57,6 @@ class ConsoleMaster(master.Master):
keymap.KeymapConfig(), keymap.KeymapConfig(),
) )
def sigint_handler(*args, **kwargs):
self.prompt_for_exit()
signal.signal(signal.SIGINT, sigint_handler)
self.window = None self.window = None
def __setattr__(self, name, value): def __setattr__(self, name, value):
@ -201,7 +192,7 @@ class ConsoleMaster(master.Master):
def inject_key(self, key): def inject_key(self, key):
self.loop.process_input([key]) self.loop.process_input([key])
def run(self): async def running(self) -> None:
if not sys.stdout.isatty(): if not sys.stdout.isatty():
print("Error: mitmproxy's console interface requires a tty. " print("Error: mitmproxy's console interface requires a tty. "
"Please run mitmproxy in an interactive shell environment.", file=sys.stderr) "Please run mitmproxy in an interactive shell environment.", file=sys.stderr)
@ -215,7 +206,8 @@ class ConsoleMaster(master.Master):
self.set_palette, self.set_palette,
["console_palette", "console_palette_transparent"] ["console_palette", "console_palette_transparent"]
) )
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
if isinstance(loop, getattr(asyncio, "ProactorEventLoop", tuple())): if isinstance(loop, getattr(asyncio, "ProactorEventLoop", tuple())):
# fix for https://bugs.python.org/issue37373 # fix for https://bugs.python.org/issue37373
loop = AddThreadSelectorEventLoop(loop) loop = AddThreadSelectorEventLoop(loop)
@ -229,13 +221,13 @@ class ConsoleMaster(master.Master):
self.loop.widget = self.window self.loop.widget = self.window
self.window.refresh() self.window.refresh()
if self.start_err: self.loop.start()
def display_err(*_):
self.sig_add_log(None, self.start_err)
self.start_err = None
self.loop.set_alarm_in(0.01, display_err)
super().run_loop(self.loop.run) await super().running()
async def done(self):
self.loop.stop()
await super().done()
def overlay(self, widget, **kwargs): def overlay(self, widget, **kwargs):
self.window.set_overlay(widget, **kwargs) self.window.set_overlay(widget, **kwargs)

View File

@ -9,7 +9,7 @@ from mitmproxy import exceptions, master
from mitmproxy import options from mitmproxy import options
from mitmproxy import optmanager from mitmproxy import optmanager
from mitmproxy.tools import cmdline from mitmproxy.tools import cmdline
from mitmproxy.utils import debug, arg_check from mitmproxy.utils import asyncio_utils, debug, arg_check
def assert_utf8_env(): def assert_utf8_env():
@ -95,20 +95,27 @@ def run(
master.log.info(f"Only processing flows that match \"{' & '.join(args.filter_args)}\"") master.log.info(f"Only processing flows that match \"{' & '.join(args.filter_args)}\"")
opts.update(**extra(args)) opts.update(**extra(args))
loop = asyncio.get_event_loop()
try:
loop.add_signal_handler(signal.SIGINT, getattr(master, "prompt_for_exit", master.shutdown))
loop.add_signal_handler(signal.SIGTERM, master.shutdown)
except NotImplementedError:
# Not supported on Windows
pass
master.run()
except exceptions.OptionsError as e: except exceptions.OptionsError as e:
print("{}: {}".format(sys.argv[0], e), file=sys.stderr) print("{}: {}".format(sys.argv[0], e), file=sys.stderr)
sys.exit(1) sys.exit(1)
except (KeyboardInterrupt, RuntimeError):
pass async def main():
loop = asyncio.get_running_loop()
def _sigint(*_):
loop.call_soon_threadsafe(getattr(master, "prompt_for_exit", master.shutdown))
def _sigterm(*_):
loop.call_soon_threadsafe(master.shutdown)
# We can't use loop.add_signal_handler because that's not available on Windows' Proactorloop,
# but signal.signal just works fine for our purposes.
signal.signal(signal.SIGINT, _sigint)
signal.signal(signal.SIGTERM, _sigterm)
return await master.run()
asyncio.run(main())
return master return master

View File

@ -92,13 +92,16 @@ class WebMaster(master.Master):
data=options_dict data=options_dict
) )
def run(self): # pragma: no cover async def running(self):
AsyncIOMainLoop().install() # Register tornado with the current event loop
iol = tornado.ioloop.IOLoop.instance() tornado.ioloop.IOLoop.current()
# Add our web app.
http_server = tornado.httpserver.HTTPServer(self.app) http_server = tornado.httpserver.HTTPServer(self.app)
http_server.listen(self.options.web_port, self.options.web_host) http_server.listen(self.options.web_port, self.options.web_host)
web_url = f"http://{self.options.web_host}:{self.options.web_port}/"
self.log.info( self.log.info(
f"Web server listening at {web_url}", f"Web server listening at http://{self.options.web_host}:{self.options.web_port}/",
) )
self.run_loop(iol.start)
return await super().running()

View File

@ -66,108 +66,3 @@ def task_repr(task: asyncio.Task) -> str:
if client: if client:
client = f"{human.format_address(client)}: " client = f"{human.format_address(client)}: "
return f"{client}{name}{age}" return f"{client}{name}{age}"
T = TypeVar("T")
def run(
main_func: Awaitable[T],
ctrl_c_handler: Callable,
sigterm_handler: Callable,
) -> T:
"""
Like `asyncio.run`, but with cross-platform Ctrl+C support.
The main problem with Ctrl+C is that it raises a KeyboardInterrupt on Windows,
which terminates the current event loop. This method here moves the event loop to a second thread,
gracefully catches KeyboardInterrupt in the main thread, and then calls the sigint handler.
"""
loop = asyncio.new_event_loop()
try:
loop.add_signal_handler(signal.SIGINT, ctrl_c_handler)
loop.add_signal_handler(signal.SIGTERM, sigterm_handler)
return _run_loop(loop, main_func)
except NotImplementedError:
pass
# Windows code path. We don't make this path of the except clause above
# because that creates "during the handling of another exception" messages
with concurrent.futures.ThreadPoolExecutor(thread_name_prefix="eventloop") as executor:
future = executor.submit(_run_loop, loop, main_func)
while True:
try:
# A larger timeout doesn't work, KeyboardInterrupt is not detected then.
return future.result(.1)
except concurrent.futures.TimeoutError:
pass
except KeyboardInterrupt:
loop.call_soon_threadsafe(ctrl_c_handler)
def _run_loop(loop: asyncio.AbstractEventLoop, main_func: Awaitable[T]) -> T:
# this method mimics what `asyncio.run` is doing.
try:
asyncio.set_event_loop(loop)
return loop.run_until_complete(main_func)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
asyncio.set_event_loop(None)
loop.close()
# copied from https://github.com/python/cpython/blob/3.10/Lib/asyncio/runners.py
def _cancel_all_tasks(loop):
to_cancel = tasks.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})
if __name__ == "__main__":
done = asyncio.Event()
async def main():
while True:
print("...")
try:
await asyncio.wait_for(done.wait(), 1)
break
except asyncio.TimeoutError:
pass
return 42
print(f"{run(main(), done.set, done.set)=}")
async def main_err():
raise RuntimeError
try:
run(main_err(), lambda: 0, lambda: 0)
except RuntimeError:
print("error propagation ok.")