PaiGram/core/builtins/reloader.py

186 lines
5.6 KiB
Python
Raw Normal View History

import inspect
import multiprocessing
import os
import signal
import threading
from pathlib import Path
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
from watchfiles import watch
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
from utils.log import logger
from utils.typedefs import StrOrPath
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
# Only the Reloader class is exported by ``from ... import *``.
__all__ = ("Reloader",)
# NOTE(review): presumably registers pickling support so connection/socket
# objects can be passed to child processes — confirm this is still required.
multiprocessing.allow_connection_pickling()
# Multiprocessing context using the "spawn" start method; Reloader creates
# the target process through ``spawn.Process``.
spawn = multiprocessing.get_context("spawn")
class FileFilter:
    """Decide whether a changed path should be reported.

    A path is accepted when it matches at least one include pattern, does not
    live under any excluded directory, and matches none of the exclude
    patterns.

    Attributes:
        includes: Glob patterns a path must match (``"*.py"`` by default).
        excludes: Glob patterns that reject a path.
        exclude_dirs: Existing directories whose contents are rejected.
    """

    def __init__(self, includes: List[str], excludes: List[str]) -> None:
        """Build the include/exclude tables.

        Args:
            includes: Extra glob patterns to watch. A pattern listed here is
                removed from the built-in exclude defaults.
            excludes: Glob patterns or directory paths to ignore. A pattern
                listed here is removed from the built-in include defaults.
        """
        # Built-in include defaults survive unless the caller excluded them.
        kept_default_includes = [pattern for pattern in ("*.py",) if pattern not in excludes]
        self.includes = list({*kept_default_includes, *includes})

        # Built-in exclude defaults (hidden/compiled/swap/backup files and
        # this module itself) survive unless the caller included them.
        builtin_excludes = (".*", ".py[cod]", ".sw.*", "~*", __file__)
        exclude_patterns = {pattern for pattern in builtin_excludes if pattern not in includes}

        # Each user exclude is either an existing directory (matched against
        # a path's ancestors) or a glob pattern (matched against the path).
        self.exclude_dirs = []
        for entry in excludes:
            candidate = Path(entry)
            try:
                entry_is_dir = candidate.is_dir()
            except OSError:
                # Presumably raised for strings that are not valid paths
                # (e.g. glob patterns on some platforms); treat as a pattern.
                entry_is_dir = False
            if entry_is_dir:
                self.exclude_dirs.append(candidate)
            else:
                exclude_patterns.add(entry)
        self.excludes = list(exclude_patterns)

    def __call__(self, path: Path) -> bool:
        """Return ``True`` if ``path`` should be reported as a change."""
        if not any(path.match(pattern) for pattern in self.includes):
            return False
        if any(directory in path.parents for directory in self.exclude_dirs):
            return False
        return not any(path.match(pattern) for pattern in self.excludes)
class Reloader:
    """Run a callable in a spawned child process and restart it on file changes.

    The instance is its own iterator: each step waits on the file watcher and
    yields either a list of changed paths that passed the filter or ``None``.
    Iteration stops once the internal stop event has been set (by the signal
    handler, by detecting that the child process died, or when the watcher is
    exhausted).
    """

    _target: Callable[..., None]
    _process: "BaseProcess"

    @property
    def process(self) -> "BaseProcess":
        """The child process currently running the target (set by ``startup``)."""
        return self._process

    @property
    def target(self) -> Callable[..., None]:
        """The callable executed in the child process."""
        return self._target

    def __init__(
        self,
        target: Callable[..., None],
        *,
        reload_delay: float = 0.25,
        reload_dirs: "Optional[List[StrOrPath]]" = None,
        reload_includes: Optional[List[str]] = None,
        reload_excludes: Optional[List[str]] = None,
    ):
        """Configure the watcher; no process is started until ``startup``/``run``.

        Args:
            target: Synchronous callable to run in the child process.
            reload_delay: Stored on the instance; not consulted elsewhere in
                this class.
            reload_dirs: Directories to watch, joined onto ``PROJECT_ROOT``.
            reload_includes: Extra glob patterns to watch.
            reload_excludes: Glob patterns or directories to ignore.

        Raises:
            ValueError: If ``target`` is a coroutine function.
        """
        if inspect.iscoroutinefunction(target):
            raise ValueError("不支持异步函数")
        self._target = target
        self.reload_delay = reload_delay

        # Watching a parent already covers its children, so drop every
        # directory that has an ancestor in the same list.
        candidates = [PROJECT_ROOT.joinpath(Path(reload_dir)) for reload_dir in reload_dirs or []]
        self.reload_dirs = [
            candidate
            for candidate in candidates
            if not any(parent in candidates for parent in candidate.parents)
        ]
        if not self.reload_dirs:
            logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})

        self._should_exit = threading.Event()

        # Exclude the file of the module that instantiates the Reloader.
        # The caller may have no ``__file__`` (e.g. an interactive session)
        # and ``currentframe`` may be unavailable, so both are optional.
        current = inspect.currentframe()
        caller = current.f_back if current is not None else None
        caller_file = caller.f_globals.get("__file__") if caller is not None else None
        excludes = list(reload_excludes or [])
        if caller_file:
            excludes.append(caller_file)
        self.watch_filter = FileFilter(reload_includes or [], excludes)

        # No filter is handed to watchfiles: filtering happens in
        # ``get_changes`` via ``self.watch_filter``.
        self.watcher = watch(
            *self.reload_dirs,
            watch_filter=None,
            stop_event=self._should_exit,
            yield_on_timeout=True,
        )

    @staticmethod
    def _display_path(path: Path) -> str:
        """Return ``path`` relative to ``PROJECT_ROOT`` (if inside it) with ``/`` separators."""
        try:
            shown = path.relative_to(PROJECT_ROOT)
        except ValueError:
            # Path lies outside the project root; show it unchanged rather
            # than crashing the reloader while logging.
            shown = path
        return str(shown).replace(os.sep, "/")

    def get_changes(self) -> Optional[List[Path]]:
        """Wait for the next watcher event and return the filtered changed paths.

        Returns:
            The changed paths accepted by ``self.watch_filter`` (possibly an
            empty list), or ``None`` when the watcher reported no changes or
            has finished.
        """
        if not self._process.is_alive():
            logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
            self._should_exit.set()
        try:
            changes = next(self.watcher)
        except StopIteration:
            # The watcher is exhausted and will never yield again; flag exit
            # so iteration over the reloader terminates instead of spinning.
            self._should_exit.set()
            return None
        if changes:
            unique_paths = {Path(change[1]) for change in changes}
            return [path for path in unique_paths if self.watch_filter(path)]
        return None

    def __iter__(self) -> Iterator[Optional[List[Path]]]:
        """Return the reloader itself as the iterator."""
        return self

    def __next__(self) -> Optional[List[Path]]:
        """Return the next batch of changes, or stop once exit was requested."""
        if self._should_exit.is_set():
            raise StopIteration
        return self.get_changes()

    def run(self) -> None:
        """Start the target, restart it on every accepted change, then shut down."""
        self.startup()
        for changes in self:
            if changes:
                logger.warning(
                    "检测到文件 %s 发生改变, 正在重载...",
                    [self._display_path(changed) for changed in changes],
                    extra={"tag": "Reloader"},
                )
                self.restart()
        self.shutdown()

    def signal_handler(self, *_) -> None:
        """Stop the child process and request exit when a handled signal arrives."""
        # Wait up to 3 seconds for the child to exit before terminating it.
        self._process.join(3)
        if self._process.is_alive():
            self._process.terminate()
            self._process.join()
        self._should_exit.set()

    def startup(self) -> None:
        """Install the signal handlers and start the child process."""
        logger.info("目标进程正在启动", extra={"tag": "Reloader"})
        for sig in HANDLED_SIGNALS:
            signal.signal(sig, self.signal_handler)
        self._process = spawn.Process(target=self._target)
        self._process.start()
        logger.success("目标进程启动成功", extra={"tag": "Reloader"})

    def restart(self) -> None:
        """Terminate the current child process and start a fresh one."""
        self._process.terminate()
        self._process.join(10)
        self._process = spawn.Process(target=self._target)
        self._process.start()
        logger.info("目标进程已经重载", extra={"tag": "Reloader"})

    def shutdown(self) -> None:
        """Terminate the child process and wait up to 10 seconds for it to exit."""
        self._process.terminate()
        self._process.join(10)
        logger.info("重载器已经关闭", extra={"tag": "Reloader"})