GramCore/builtins/reloader.py
2023-08-05 10:41:12 +08:00

186 lines
5.6 KiB
Python

import inspect
import multiprocessing
import os
import signal
import threading
from pathlib import Path
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
from watchfiles import watch
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
from utils.log import logger
from utils.typedefs import StrOrPath
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
__all__ = ("Reloader",)
multiprocessing.allow_connection_pickling()
spawn = multiprocessing.get_context("spawn")
class FileFilter:
"""监控文件过滤"""
def __init__(self, includes: List[str], excludes: List[str]) -> None:
default_includes = ["*.py"]
self.includes = [default for default in default_includes if default not in excludes]
self.includes.extend(includes)
self.includes = list(set(self.includes))
default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
self.excludes = [default for default in default_excludes if default not in includes]
self.exclude_dirs = []
for e in excludes:
p = Path(e)
try:
is_dir = p.is_dir()
except OSError:
is_dir = False
if is_dir:
self.exclude_dirs.append(p)
else:
self.excludes.append(e)
self.excludes = list(set(self.excludes))
def __call__(self, path: Path) -> bool:
for include_pattern in self.includes:
if path.match(include_pattern):
for exclude_dir in self.exclude_dirs:
if exclude_dir in path.parents:
return False
for exclude_pattern in self.excludes:
if path.match(exclude_pattern):
return False
return True
return False
class Reloader:
_target: Callable[..., None]
_process: "BaseProcess"
@property
def process(self) -> "BaseProcess":
return self._process
@property
def target(self) -> Callable[..., None]:
return self._target
def __init__(
self,
target: Callable[..., None],
*,
reload_delay: float = 0.25,
reload_dirs: List[StrOrPath] = None,
reload_includes: List[str] = None,
reload_excludes: List[str] = None,
):
if inspect.iscoroutinefunction(target):
raise ValueError("不支持异步函数")
self._target = target
self.reload_delay = reload_delay
_reload_dirs = []
for reload_dir in reload_dirs or []:
_reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir)))
self.reload_dirs = []
for reload_dir in _reload_dirs:
append = True
for parent in reload_dir.parents:
if parent in _reload_dirs:
append = False
break
if append:
self.reload_dirs.append(reload_dir)
if not self.reload_dirs:
logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})
self._should_exit = threading.Event()
frame = inspect.currentframe().f_back
self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]])
self.watcher = watch(
*self.reload_dirs,
watch_filter=None,
stop_event=self._should_exit,
yield_on_timeout=True,
)
def get_changes(self) -> Optional[List[Path]]:
if not self._process.is_alive():
logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
self._should_exit.set()
try:
changes = next(self.watcher)
except StopIteration:
return None
if changes:
unique_paths = {Path(c[1]) for c in changes}
return [p for p in unique_paths if self.watch_filter(p)]
return None
def __iter__(self) -> Iterator[Optional[List[Path]]]:
return self
def __next__(self) -> Optional[List[Path]]:
return self.get_changes()
def run(self) -> None:
self.startup()
for changes in self:
if changes:
logger.warning(
"检测到文件 %s 发生改变, 正在重载...",
[str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes],
extra={"tag": "Reloader"},
)
self.restart()
self.shutdown()
def signal_handler(self, *_) -> None:
"""当接收到结束信号量时"""
self._process.join(3)
if self._process.is_alive():
self._process.terminate()
self._process.join()
self._should_exit.set()
def startup(self) -> None:
"""启动进程"""
logger.info("目标进程正在启动", extra={"tag": "Reloader"})
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.signal_handler)
self._process = spawn.Process(target=self._target)
self._process.start()
logger.success("目标进程启动成功", extra={"tag": "Reloader"})
def restart(self) -> None:
"""重启进程"""
self._process.terminate()
self._process.join(10)
self._process = spawn.Process(target=self._target)
self._process.start()
logger.info("目标进程已经重载", extra={"tag": "Reloader"})
def shutdown(self) -> None:
"""关闭进程"""
self._process.terminate()
self._process.join(10)
logger.info("重载器已经关闭", extra={"tag": "Reloader"})