mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-16 12:51:35 +00:00
233e7ab58d
Co-authored-by: luoshuijs <luoshuijs@outlook.com> Co-authored-by: Karako <karakohear@gmail.com> Co-authored-by: xtaodada <xtao@xtaolink.cn>
186 lines
5.6 KiB
Python
186 lines
5.6 KiB
Python
import inspect
|
|
import multiprocessing
|
|
import os
|
|
import signal
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
|
|
|
|
from watchfiles import watch
|
|
|
|
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
|
|
from utils.log import logger
|
|
from utils.typedefs import StrOrPath
|
|
|
|
if TYPE_CHECKING:
|
|
from multiprocessing.process import BaseProcess
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ("Reloader",)

# Ask multiprocessing to install support for pickling connection objects
# so they can be passed between processes.
multiprocessing.allow_connection_pickling()
# Multiprocessing context that creates child processes with the "spawn" start method.
spawn = multiprocessing.get_context("spawn")
|
|
|
|
|
|
class FileFilter:
    """Filter deciding whether a changed file is relevant for reloading.

    A path is accepted when it matches at least one include pattern and is
    neither located under an excluded directory nor matched by an exclude
    pattern.

    Attributes:
        includes: Glob patterns a path has to match (``*.py`` by default).
        excludes: Glob patterns that reject a path (hidden files, byte code,
            editor swap/backup files and this module itself by default).
        exclude_dirs: Directories whose whole content is rejected.
    """

    def __init__(self, includes: List[str], excludes: List[str]) -> None:
        """Build the include/exclude pattern lists.

        Args:
            includes: Extra glob patterns to accept. A default exclude pattern
                listed here is dropped from the excludes.
            excludes: Glob patterns or existing directories to reject. A
                default include pattern listed here is dropped from the
                includes.
        """
        default_includes = ["*.py"]
        self.includes = [default for default in default_includes if default not in excludes]
        self.includes.extend(includes)
        self.includes = list(set(self.includes))

        default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
        self.excludes = [default for default in default_excludes if default not in includes]
        self.exclude_dirs: List[Path] = []
        for e in excludes:
            p = Path(e)
            try:
                is_dir = p.is_dir()
            except OSError:
                # Some platforms raise for values that are patterns rather
                # than real paths; treat those as plain patterns.
                is_dir = False

            if is_dir:
                self.exclude_dirs.append(p)
                # ``is_dir()`` resolves a relative entry against the current
                # working directory, but candidate paths are usually absolute.
                # Also remember the absolute form so that the exclusion is not
                # silently ignored for them.
                try:
                    resolved = p.resolve()
                except (OSError, RuntimeError):
                    continue
                if resolved != p:
                    self.exclude_dirs.append(resolved)
            else:
                self.excludes.append(e)
        self.excludes = list(set(self.excludes))

    def __call__(self, path: Path) -> bool:
        """Return ``True`` when *path* should be reported as a change."""
        for include_pattern in self.includes:
            if not path.match(include_pattern):
                continue

            parents = path.parents
            if any(exclude_dir in parents for exclude_dir in self.exclude_dirs):
                return False
            if any(path.match(exclude_pattern) for exclude_pattern in self.excludes):
                return False
            return True
        return False
|
|
|
|
|
|
class Reloader:
    """Run a callable in a spawned child process and restart it on file changes.

    The reloader watches the configured directories with ``watchfiles`` and,
    whenever a relevant file changes, terminates the child process and starts
    a fresh one running the same target.

    Iterating over the instance yields one (possibly empty or ``None``) batch
    of changed paths per watcher step and stops once the watcher has stopped.
    """

    _target: Callable[..., None]
    _process: "BaseProcess"

    @property
    def process(self) -> "BaseProcess":
        """The child process currently running the target."""
        return self._process

    @property
    def target(self) -> Callable[..., None]:
        """The callable executed inside the child process."""
        return self._target

    def __init__(
        self,
        target: Callable[..., None],
        *,
        reload_delay: float = 0.25,
        reload_dirs: Optional[List[StrOrPath]] = None,
        reload_includes: Optional[List[str]] = None,
        reload_excludes: Optional[List[str]] = None,
    ):
        """Configure the watcher; the child process is started by ``startup``.

        Args:
            target: Synchronous callable to run in the child process.
            reload_delay: Stored on the instance as ``reload_delay``; it is
                not consumed anywhere inside this class.
            reload_dirs: Directories to watch, joined onto ``PROJECT_ROOT``.
            reload_includes: Extra glob patterns of files that trigger a reload.
            reload_excludes: Glob patterns or directories that never trigger
                a reload.

        Raises:
            ValueError: If *target* is a coroutine function.
        """
        if inspect.iscoroutinefunction(target):
            raise ValueError("不支持异步函数")
        self._target = target

        self.reload_delay = reload_delay

        candidates = [PROJECT_ROOT.joinpath(Path(reload_dir)) for reload_dir in reload_dirs or []]
        # Drop directories nested inside another watched directory: the
        # parent already covers them.
        self.reload_dirs = [
            candidate
            for candidate in candidates
            if not any(parent in candidates for parent in candidate.parents)
        ]

        if not self.reload_dirs:
            logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})

        self._should_exit = threading.Event()

        # Exclude the file of the module that created the reloader. The frame
        # or its ``__file__`` may be unavailable (e.g. interactive sessions),
        # in which case nothing extra is excluded.
        excludes = list(reload_excludes or [])
        current = inspect.currentframe()
        caller = current.f_back if current is not None else None
        caller_file = caller.f_globals.get("__file__") if caller is not None else None
        if caller_file:
            excludes.append(caller_file)

        self.watch_filter = FileFilter(reload_includes or [], excludes)
        # Filtering is applied in ``_poll_changes`` via ``self.watch_filter``,
        # so no filter is handed to ``watch`` itself.
        self.watcher = watch(
            *self.reload_dirs,
            watch_filter=None,
            stop_event=self._should_exit,
            yield_on_timeout=True,
        )

    def _poll_changes(self) -> Optional[List[Path]]:
        """Fetch the next batch of relevant changes from the watcher.

        Sets the stop event when the child process is no longer alive.

        Raises:
            StopIteration: When the watcher has stopped.
        """
        if not self._process.is_alive():
            logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
            self._should_exit.set()
        changes = next(self.watcher)
        if changes:
            unique_paths = {Path(c[1]) for c in changes}
            return [p for p in unique_paths if self.watch_filter(p)]
        return None

    def get_changes(self) -> Optional[List[Path]]:
        """Return the changed paths that pass the filter, or ``None``.

        ``None`` is returned both when nothing changed and when the watcher
        has stopped.
        """
        try:
            return self._poll_changes()
        except StopIteration:
            return None

    def __iter__(self) -> Iterator[Optional[List[Path]]]:
        return self

    def __next__(self) -> Optional[List[Path]]:
        """Return the next batch of changes.

        ``StopIteration`` from the stopped watcher is propagated on purpose:
        otherwise the loop in ``run`` would never end and ``shutdown`` would
        never be reached.
        """
        return self._poll_changes()

    @staticmethod
    def _display_path(path: Path) -> str:
        """Return *path* relative to the project root when possible, using ``/``."""
        try:
            shown = path.relative_to(PROJECT_ROOT)
        except ValueError:
            # Path lies outside the project root; show it unchanged.
            shown = path
        return str(shown).replace(os.sep, "/")

    def run(self) -> None:
        """Start the target and keep restarting it on changes until stopped."""
        self.startup()
        for changes in self:
            if changes:
                logger.warning(
                    "检测到文件 %s 发生改变, 正在重载...",
                    [self._display_path(c) for c in changes],
                    extra={"tag": "Reloader"},
                )
                self.restart()

        self.shutdown()

    def signal_handler(self, *_) -> None:
        """Handle a termination signal.

        Waits up to 3 seconds for the child process to exit, terminates it if
        it is still alive, then sets the stop event for the watcher.
        """
        self._process.join(3)
        if self._process.is_alive():
            self._process.terminate()
            self._process.join()
        self._should_exit.set()

    def startup(self) -> None:
        """Install the signal handlers and start the child process."""
        logger.info("目标进程正在启动", extra={"tag": "Reloader"})

        for sig in HANDLED_SIGNALS:
            signal.signal(sig, self.signal_handler)

        self._process = spawn.Process(target=self._target)
        self._process.start()
        logger.success("目标进程启动成功", extra={"tag": "Reloader"})

    def restart(self) -> None:
        """Terminate the child process and start a new one."""
        self._process.terminate()
        self._process.join(10)

        self._process = spawn.Process(target=self._target)
        self._process.start()
        logger.info("目标进程已经重载", extra={"tag": "Reloader"})

    def shutdown(self) -> None:
        """Terminate the child process and wait up to 10 seconds for it."""
        self._process.terminate()
        self._process.join(10)

        logger.info("重载器已经关闭", extra={"tag": "Reloader"})
|