mirror of
https://github.com/PaiGramTeam/GramCore.git
synced 2024-11-25 07:36:27 +00:00
315 lines
11 KiB
Python
315 lines
11 KiB
Python
"""BOT"""
|
||
import asyncio
|
||
import signal
|
||
from functools import wraps
|
||
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
|
||
from ssl import SSLZeroReturnError
|
||
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar, Union
|
||
|
||
import pytz
|
||
import uvicorn
|
||
from fastapi import FastAPI
|
||
from telegram import Bot, Update
|
||
from telegram.error import NetworkError, TelegramError, TimedOut
|
||
from telegram.ext import (
|
||
Application as TelegramApplication,
|
||
ApplicationBuilder as TelegramApplicationBuilder,
|
||
Defaults,
|
||
JobQueue,
|
||
)
|
||
from typing_extensions import ParamSpec
|
||
from uvicorn import Server
|
||
|
||
from gram_core.config import config as application_config
|
||
from gram_core.handler.limiterhandler import LimiterHandler
|
||
from gram_core.manager import Managers
|
||
from gram_core.override.telegram import HTTPXRequest
|
||
from gram_core.ratelimiter import RateLimiter
|
||
from utils.const import WRAPPER_ASSIGNMENTS
|
||
from utils.log import logger
|
||
from utils.models.signal import Singleton
|
||
|
||
if TYPE_CHECKING:
|
||
from asyncio import Task
|
||
from types import FrameType
|
||
from uvicorn._types import ASGIApplication
|
||
from gram_core.ratelimiter import T_CalledAPIFunc
|
||
|
||
__all__ = ("Application",)
|
||
|
||
R = TypeVar("R")
|
||
T = TypeVar("T")
|
||
P = ParamSpec("P")
|
||
|
||
|
||
class Application(Singleton):
|
||
"""Application"""
|
||
|
||
_web_server_task: Optional["Task"] = None
|
||
|
||
_startup_funcs: List[Callable] = []
|
||
_shutdown_funcs: List[Callable] = []
|
||
_called_api_funcs: List["T_CalledAPIFunc"] = []
|
||
|
||
def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
|
||
self._running = False
|
||
self.managers = managers
|
||
self.telegram = telegram
|
||
self.web_server = web_server
|
||
self.managers.set_application(application=self) # 给 managers 设置 application
|
||
self.managers.build_executor("Application")
|
||
|
||
@classmethod
|
||
def build(cls):
|
||
managers = Managers()
|
||
rate_limiter = RateLimiter()
|
||
telegram = (
|
||
TelegramApplicationBuilder()
|
||
.get_updates_read_timeout(application_config.update_read_timeout)
|
||
.get_updates_write_timeout(application_config.update_write_timeout)
|
||
.get_updates_connect_timeout(application_config.update_connect_timeout)
|
||
.get_updates_pool_timeout(application_config.update_pool_timeout)
|
||
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
|
||
.token(application_config.bot_token)
|
||
.base_url(application_config.bot_base_url)
|
||
.base_file_url(application_config.bot_base_file_url)
|
||
.request(
|
||
HTTPXRequest(
|
||
connection_pool_size=application_config.connection_pool_size,
|
||
proxy_url=application_config.proxy_url,
|
||
read_timeout=application_config.read_timeout,
|
||
write_timeout=application_config.write_timeout,
|
||
connect_timeout=application_config.connect_timeout,
|
||
pool_timeout=application_config.pool_timeout,
|
||
)
|
||
)
|
||
.rate_limiter(rate_limiter)
|
||
.build()
|
||
)
|
||
web_server = Server(
|
||
uvicorn.Config(
|
||
app=FastAPI(debug=application_config.debug),
|
||
port=application_config.webserver.port,
|
||
host=application_config.webserver.host,
|
||
log_config=None,
|
||
)
|
||
)
|
||
instance = cls(managers, telegram, web_server)
|
||
rate_limiter.set_application(instance)
|
||
return instance
|
||
|
||
@property
|
||
def running(self) -> bool:
|
||
"""bot 是否正在运行"""
|
||
with self._lock:
|
||
return self._running
|
||
|
||
@property
|
||
def web_app(self) -> Union["ASGIApplication", Callable, str]:
|
||
"""fastapi app"""
|
||
return self.web_server.config.app
|
||
|
||
@property
|
||
def bot(self) -> Optional[Bot]:
|
||
return self.telegram.bot
|
||
|
||
@property
|
||
def job_queue(self) -> Optional[JobQueue]:
|
||
return self.telegram.job_queue
|
||
|
||
async def _on_startup(self) -> None:
|
||
for func in self._startup_funcs:
|
||
await self.managers.executor(func, block=getattr(func, "block", False))
|
||
|
||
async def _on_shutdown(self) -> None:
|
||
for func in self._shutdown_funcs:
|
||
await self.managers.executor(func, block=getattr(func, "block", False))
|
||
|
||
async def initialize(self):
|
||
"""BOT 初始化"""
|
||
self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制
|
||
await self.managers.start_dependency() # 启动基础服务
|
||
await self.managers.init_components() # 实例化组件
|
||
await self.managers.start_services() # 启动其他服务
|
||
await self.managers.install_plugins() # 安装插件
|
||
|
||
async def shutdown(self):
|
||
"""BOT 关闭"""
|
||
await self.managers.uninstall_plugins() # 卸载插件
|
||
await self.managers.stop_services() # 终止其他服务
|
||
await self.managers.stop_dependency() # 终止基础服务
|
||
|
||
async def start(self) -> None:
|
||
"""启动 BOT"""
|
||
logger.info("正在启动 BOT 中...")
|
||
|
||
def error_callback(exc: TelegramError) -> None:
|
||
"""错误信息回调"""
|
||
self.telegram.create_task(self.telegram.process_error(error=exc, update=None))
|
||
|
||
await self.telegram.initialize()
|
||
logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})
|
||
|
||
if application_config.webserver.enable: # 如果使用 web app
|
||
server_config = self.web_server.config
|
||
server_config.setup_event_loop()
|
||
if not server_config.loaded:
|
||
server_config.load()
|
||
self.web_server.lifespan = server_config.lifespan_class(server_config)
|
||
try:
|
||
await self.web_server.startup()
|
||
except OSError as e:
|
||
if e.errno == 10048:
|
||
logger.error("Web Server 端口被占用:%s", e)
|
||
logger.error("Web Server 启动失败,正在退出")
|
||
raise SystemExit from None
|
||
|
||
if self.web_server.should_exit:
|
||
logger.error("Web Server 启动失败,正在退出")
|
||
raise SystemExit from None
|
||
logger.success("Web Server 启动成功")
|
||
|
||
self._web_server_task = asyncio.create_task(self.web_server.main_loop())
|
||
|
||
for _ in range(5): # 连接至 telegram 服务器
|
||
try:
|
||
await self.telegram.updater.start_polling(
|
||
error_callback=error_callback, allowed_updates=Update.ALL_TYPES
|
||
)
|
||
break
|
||
except TimedOut:
|
||
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
|
||
continue
|
||
except NetworkError as e:
|
||
logger.exception()
|
||
if isinstance(e, SSLZeroReturnError):
|
||
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||
else:
|
||
logger.error("网络连接出现问题, 请检查您的网络状况.")
|
||
raise SystemExit from e
|
||
|
||
await self.initialize()
|
||
logger.success("BOT 初始化成功")
|
||
logger.debug("BOT 开始启动")
|
||
|
||
await self._on_startup()
|
||
await self.telegram.start()
|
||
self._running = True
|
||
logger.success("BOT 启动成功")
|
||
|
||
def stop_signal_handler(self, signum: int):
|
||
"""终止信号处理"""
|
||
signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")}
|
||
logger.debug("接收到了终止信号 %s 正在退出...", signals[signum])
|
||
if self._web_server_task:
|
||
self._web_server_task.cancel()
|
||
|
||
async def idle(self) -> None:
|
||
"""在接收到中止信号之前,堵塞loop"""
|
||
|
||
task = None
|
||
|
||
def stop_handler(signum: int, _: "FrameType") -> None:
|
||
self.stop_signal_handler(signum)
|
||
task.cancel()
|
||
|
||
for s in (SIGINT, SIGTERM, SIGABRT):
|
||
signal_func(s, stop_handler)
|
||
|
||
while True:
|
||
task = asyncio.create_task(asyncio.sleep(600))
|
||
|
||
try:
|
||
await task
|
||
except asyncio.CancelledError:
|
||
break
|
||
|
||
async def stop(self) -> None:
|
||
"""关闭"""
|
||
logger.info("BOT 正在关闭")
|
||
self._running = False
|
||
|
||
await self._on_shutdown()
|
||
|
||
if self.telegram.updater.running:
|
||
await self.telegram.updater.stop()
|
||
|
||
await self.shutdown()
|
||
|
||
if self.telegram.running:
|
||
await self.telegram.stop()
|
||
|
||
await self.telegram.shutdown()
|
||
if self.web_server is not None:
|
||
try:
|
||
await self.web_server.shutdown()
|
||
logger.info("Web Server 已经关闭")
|
||
except AttributeError:
|
||
pass
|
||
|
||
logger.success("BOT 关闭成功")
|
||
|
||
def launch(self) -> None:
|
||
"""启动"""
|
||
loop = asyncio.get_event_loop()
|
||
try:
|
||
loop.run_until_complete(self.start())
|
||
loop.run_until_complete(self.idle())
|
||
except (SystemExit, KeyboardInterrupt) as exc:
|
||
logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号
|
||
except NetworkError as e:
|
||
if isinstance(e, SSLZeroReturnError):
|
||
logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||
else:
|
||
logger.critical("网络连接出现问题, 请检查您的网络状况.")
|
||
except Exception as e:
|
||
logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e)
|
||
finally:
|
||
loop.run_until_complete(self.stop())
|
||
|
||
if application_config.reload:
|
||
raise SystemExit from None
|
||
|
||
def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
|
||
"""注册一个在 BOT 启动时执行的函数"""
|
||
|
||
if func not in self._startup_funcs:
|
||
self._startup_funcs.append(func)
|
||
|
||
# noinspection PyTypeChecker
|
||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||
return func(*args, **kwargs)
|
||
|
||
return wrapper
|
||
|
||
def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
|
||
"""注册一个在 BOT 停止时执行的函数"""
|
||
|
||
if func not in self._shutdown_funcs:
|
||
self._shutdown_funcs.append(func)
|
||
|
||
# noinspection PyTypeChecker
|
||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||
return func(*args, **kwargs)
|
||
|
||
return wrapper
|
||
|
||
def on_called_api(self, func: "T_CalledAPIFunc") -> Callable[P, R]:
|
||
"""注册一个在 BOT 调用 Telegram API 后执行的函数"""
|
||
|
||
if func not in self._called_api_funcs:
|
||
self._called_api_funcs.append(func)
|
||
|
||
# noinspection PyTypeChecker
|
||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||
return func(*args, **kwargs)
|
||
|
||
return wrapper
|
||
|
||
def get_called_api_funcs(self) -> List["T_CalledAPIFunc"]:
|
||
"""获取所有在 BOT 调用 Telegram API 后执行的函数"""
|
||
return self._called_api_funcs
|