GramCore/application.py
2024-06-23 00:26:36 +08:00

359 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""BOT"""
import asyncio
import signal
from functools import wraps
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
from ssl import SSLZeroReturnError
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar, Union
import pytz
import uvicorn
from fastapi import FastAPI
from starlette.requests import Request
from starlette.responses import Response
from telegram import Bot, Update
from telegram.error import NetworkError, TelegramError, TimedOut
from telegram.ext import (
Application as TelegramApplication,
ApplicationBuilder as TelegramApplicationBuilder,
Defaults,
JobQueue,
)
from telegram.request import HTTPXRequest
from typing_extensions import ParamSpec
from uvicorn import Server
from gram_core.config import config as application_config
from gram_core.handler.limiterhandler import LimiterHandler
from gram_core.manager import Managers
from gram_core.ratelimiter import RateLimiter
from utils.const import WRAPPER_ASSIGNMENTS
from utils.log import logger
from utils.models.signal import Singleton
if TYPE_CHECKING:
from asyncio import Task
from telegram import Bot
from types import FrameType
from gram_core.ratelimiter import T_CalledAPIFunc
from gram_core.handler.hookhandler import T_PreprocessorsFunc
# Public API of this module: only the Application class is exported.
__all__ = ("Application",)
# Generic helpers used in the type hints of the registration decorators.
R = TypeVar("R")  # return type of a registered callable
T = TypeVar("T")  # general-purpose type variable
P = ParamSpec("P")  # parameter specification of a registered callable
class Application(Singleton):
    """Top-level object tying together the Telegram application, the web server and the managers.

    It owns the life cycle of the bot: ``launch`` runs ``start`` -> ``idle`` -> ``stop``
    on the event loop, and the ``on_*`` / ``run_preprocessor`` methods let other code
    register callbacks.
    """

    # Background task running the web server main loop; set by ``start`` when the
    # web server is enabled.
    _web_server_task: Optional["Task"] = None

    # Callback registries. These are class-level lists, so they are shared by every
    # reference to the class (the class derives from ``Singleton``).
    _startup_funcs: List[Callable] = []
    _shutdown_funcs: List[Callable] = []
    _called_api_funcs: List["T_CalledAPIFunc"] = []
    _preprocessors_funcs: List["T_PreprocessorsFunc"] = []

    def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
        """Store the collaborators and hook this application into the managers.

        Args:
            managers: The managers object that starts/stops dependencies, components,
                services and plugins.
            telegram: The python-telegram-bot application.
            web_server: The uvicorn server hosting the web app.
        """
        self._running = False
        self.managers = managers
        self.telegram = telegram
        self.web_server = web_server
        self.managers.set_application(application=self)  # give the managers a reference to this application
        self.managers.build_executor("Application")

    @classmethod
    def build(cls):
        """Create an ``Application`` from the values in the application config.

        Builds the managers, the rate limiter, the Telegram application (timeouts,
        token, base URLs, HTTP request settings) and the uvicorn server, then wires
        the rate limiter back to the new instance.

        Returns:
            The newly constructed instance.
        """
        managers = Managers()
        rate_limiter = RateLimiter()
        telegram = (
            TelegramApplicationBuilder()
            .get_updates_read_timeout(application_config.update_read_timeout)
            .get_updates_write_timeout(application_config.update_write_timeout)
            .get_updates_connect_timeout(application_config.update_connect_timeout)
            .get_updates_pool_timeout(application_config.update_pool_timeout)
            .defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai"), allow_sending_without_reply=True))
            .token(application_config.bot.token)
            .base_url(application_config.bot.base_url)
            .base_file_url(application_config.bot.base_file_url)
            .request(
                HTTPXRequest(
                    connection_pool_size=application_config.connection_pool_size,
                    proxy_url=application_config.proxy_url,
                    read_timeout=application_config.read_timeout,
                    write_timeout=application_config.write_timeout,
                    connect_timeout=application_config.connect_timeout,
                    pool_timeout=application_config.pool_timeout,
                    media_write_timeout=application_config.write_timeout,
                )
            )
            .rate_limiter(rate_limiter)
            .build()
        )
        web_server = Server(
            uvicorn.Config(
                app=FastAPI(debug=application_config.debug),
                port=application_config.webserver.port,
                host=application_config.webserver.host,
                log_config=None,
            )
        )
        instance = cls(managers, telegram, web_server)
        rate_limiter.set_application(instance)
        return instance

    @property
    def running(self) -> bool:
        """Whether the bot is currently running."""
        # NOTE(review): ``_lock`` is not defined in this class; presumably it is
        # provided by ``Singleton`` - confirm.
        with self._lock:
            return self._running

    @property
    def web_app(self) -> Union["FastAPI", Callable, str]:
        """The app object held by the web server config (a FastAPI app when created by ``build``)."""
        return self.web_server.config.app

    @property
    def bot(self) -> Optional[Bot]:
        """The bot object of the Telegram application."""
        return self.telegram.bot

    @property
    def job_queue(self) -> Optional[JobQueue]:
        """The job queue of the Telegram application."""
        return self.telegram.job_queue

    async def _on_startup(self) -> None:
        """Run every registered startup function through the managers' executor."""
        for func in self._startup_funcs:
            # A function may carry a ``block`` attribute; it defaults to False.
            await self.managers.executor(func, block=getattr(func, "block", False))

    async def _on_shutdown(self) -> None:
        """Run every registered shutdown function through the managers' executor."""
        for func in self._shutdown_funcs:
            await self.managers.executor(func, block=getattr(func, "block", False))

    async def initialize(self):
        """Initialise the bot: install the limiter handler and bring up the managers."""
        self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1)  # enable flood limiting at the entry point
        await self.managers.start_dependency()  # start the base services
        await self.managers.init_components()  # instantiate the components
        await self.managers.start_services()  # start the other services
        await self.managers.install_plugins()  # install the plugins

    async def shutdown(self):
        """Tear down what ``initialize`` brought up, in reverse order."""
        await self.managers.uninstall_plugins()  # uninstall the plugins
        await self.managers.stop_services()  # stop the other services
        await self.managers.stop_dependency()  # stop the base services

    async def start(self) -> None:
        """Start the bot.

        Initialises the Telegram application, optionally starts the web server,
        connects to Telegram, initialises the managers, runs the startup callbacks
        and finally starts the Telegram application.

        Raises:
            SystemExit: If the web server cannot be started, or if ``start_bot``
                gives up connecting to Telegram.
        """
        import errno  # local import: only needed for the address-in-use check below

        logger.info("正在启动 BOT 中...")
        await self.telegram.initialize()
        logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})

        if application_config.webserver.enable:  # the web app is in use
            server_config = self.web_server.config
            server_config.setup_event_loop()
            if not server_config.loaded:
                server_config.load()
            self.web_server.lifespan = server_config.lifespan_class(server_config)
            try:
                await self.web_server.startup()
            except OSError as e:
                # 10048 is the Windows socket code for "address already in use";
                # errno.EADDRINUSE covers the other platforms as well.
                if e.errno in (errno.EADDRINUSE, 10048):
                    logger.error("Web Server 端口被占用:%s", e)
                logger.error("Web Server 启动失败,正在退出")
                raise SystemExit from None

            if self.web_server.should_exit:
                logger.error("Web Server 启动失败,正在退出")
                raise SystemExit from None
            logger.success("Web Server 启动成功")

            # Keep the server's main loop running in the background.
            self._web_server_task = asyncio.create_task(self.web_server.main_loop())

        await self.start_bot()
        await self.initialize()
        logger.success("BOT 初始化成功")
        logger.debug("BOT 开始启动")

        await self._on_startup()
        await self.telegram.start()
        self._running = True
        logger.success("BOT 启动成功")

    async def start_bot(self):
        """Connect the bot to Telegram, via webhook or long polling.

        Webhook mode is used when both the bot's webhook option and the web server
        are enabled; otherwise any webhook is deleted and polling is started. The
        connection is attempted up to five times when it times out.

        Raises:
            SystemExit: On a non-timeout network error, or when every attempt timed out.
        """

        def error_callback(exc: TelegramError) -> None:
            """Forward polling errors to the Telegram application's error processing."""
            self.telegram.create_task(self.telegram.process_error(error=exc, update=None))

        use_webhook = application_config.bot.is_webhook and application_config.webserver.enable
        if use_webhook:
            # Only the route is registered here; the webhook itself is set inside
            # the retry loop so that a timeout is retried instead of escaping.
            self.register_bot_route()

        for _ in range(5):  # connect to the telegram server
            try:
                if use_webhook:
                    await self.bot.set_webhook(application_config.bot.webhook_url)
                else:
                    await self.bot.delete_webhook()
                    await self.telegram.updater.start_polling(
                        error_callback=error_callback, allowed_updates=Update.ALL_TYPES
                    )
                break
            except TimedOut:
                logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
                continue
            except NetworkError as e:
                logger.exception()
                # NOTE(review): ``e`` is caught as a telegram NetworkError, so this
                # isinstance check against the ssl exception may never be true - confirm.
                if isinstance(e, SSLZeroReturnError):
                    logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
                else:
                    logger.error("网络连接出现问题, 请检查您的网络状况.")
                raise SystemExit from e
        else:
            # Every attempt timed out: do not continue starting a bot that
            # receives no updates.
            logger.error("多次重试后仍无法连接至 telegram 服务器, 正在退出")
            raise SystemExit

    def register_bot_route(self):
        """Register the webhook route ``POST /telegram`` on the web app."""

        @self.web_app.post("/telegram")
        async def telegram(request: Request) -> Response:
            """Handle incoming Telegram updates by putting them into the `update_queue`"""
            await self.telegram.updater.update_queue.put(Update.de_json(data=await request.json(), bot=self.bot))
            return Response()

    def stop_signal_handler(self, signum: int):
        """Log the received termination signal and cancel the web server task.

        Args:
            signum: The number of the received signal.
        """
        try:
            signal_name = signal.Signals(signum).name
        except ValueError:
            # Not a signal number known to this platform; log the raw number.
            signal_name = str(signum)
        logger.debug("接收到了终止信号 %s 正在退出...", signal_name)
        if self._web_server_task:
            self._web_server_task.cancel()

    async def idle(self) -> None:
        """Block the loop until a termination signal (SIGINT, SIGTERM, SIGABRT) is received."""
        task = None
        stop_requested = False

        def stop_handler(signum: int, _: "FrameType") -> None:
            nonlocal stop_requested
            stop_requested = True
            self.stop_signal_handler(signum)
            # The signal may arrive before the first sleep task has been created.
            if task is not None:
                task.cancel()

        for s in (SIGINT, SIGTERM, SIGABRT):
            signal_func(s, stop_handler)

        while not stop_requested:
            # Sleep in long slices; the handler cancels the current slice to wake us up.
            task = asyncio.create_task(asyncio.sleep(600))
            try:
                await task
            except asyncio.CancelledError:
                break

    async def stop(self) -> None:
        """Shut everything down: callbacks, updater, managers, Telegram application, web server."""
        logger.info("BOT 正在关闭")
        self._running = False

        await self._on_shutdown()

        if self.telegram.updater.running:
            await self.telegram.updater.stop()

        await self.shutdown()

        if self.telegram.running:
            await self.telegram.stop()

        await self.telegram.shutdown()

        if self.web_server is not None:
            try:
                await self.web_server.shutdown()
                logger.info("Web Server 已经关闭")
            except AttributeError:
                # NOTE(review): presumably raised when the server was never started
                # (e.g. web server disabled) - confirm. Deliberately ignored.
                pass

        logger.success("BOT 关闭成功")

    def launch(self) -> None:
        """Run the bot: start it, idle until a stop signal, then always stop it.

        Raises:
            SystemExit: After shutdown, when the ``reload`` option of the config is set.
        """
        loop = asyncio.get_event_loop()
        try:
            loop.run_until_complete(self.start())
            loop.run_until_complete(self.idle())
        except (SystemExit, KeyboardInterrupt) as exc:
            logger.debug("接收到了终止信号BOT 即将关闭", exc_info=exc)  # a termination signal was received
        except NetworkError as e:
            # NOTE(review): same isinstance caveat as in ``start_bot`` - confirm.
            if isinstance(e, SSLZeroReturnError):
                logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
            else:
                logger.critical("网络连接出现问题, 请检查您的网络状况.")
        except Exception as e:
            # Pass the type itself; wrapping it in braces would log a one-element set.
            logger.critical("遇到了未知错误: %s", type(e), exc_info=e)
        finally:
            loop.run_until_complete(self.stop())

            if application_config.reload:
                raise SystemExit from None

    def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
        """Register a function to be executed when the bot starts (usable as a decorator).

        Args:
            func: The function to register; it is added at most once.

        Returns:
            An async wrapper that calls ``func`` with the given arguments.
        """
        if func not in self._startup_funcs:
            self._startup_funcs.append(func)

        # noinspection PyTypeChecker
        @wraps(func, assigned=WRAPPER_ASSIGNMENTS)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return func(*args, **kwargs)

        return wrapper

    def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
        """Register a function to be executed when the bot stops (usable as a decorator).

        Args:
            func: The function to register; it is added at most once.

        Returns:
            An async wrapper that calls ``func`` with the given arguments.
        """
        if func not in self._shutdown_funcs:
            self._shutdown_funcs.append(func)

        # noinspection PyTypeChecker
        @wraps(func, assigned=WRAPPER_ASSIGNMENTS)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return func(*args, **kwargs)

        return wrapper

    def on_called_api(self, func: "T_CalledAPIFunc") -> Callable[P, R]:
        """Register a function to be executed after the bot calls the Telegram API.

        Args:
            func: The function to register; it is added at most once.

        Returns:
            An async wrapper that calls ``func`` with the given arguments.
        """
        if func not in self._called_api_funcs:
            self._called_api_funcs.append(func)

        # noinspection PyTypeChecker
        @wraps(func, assigned=WRAPPER_ASSIGNMENTS)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return func(*args, **kwargs)

        return wrapper

    def get_called_api_funcs(self) -> List["T_CalledAPIFunc"]:
        """Return the list of functions registered with ``on_called_api``."""
        return self._called_api_funcs

    def run_preprocessor(self, func: "T_PreprocessorsFunc") -> Callable[P, R]:
        """Register a function to be executed before the bot calls a plugin function.

        Args:
            func: The function to register; it is added at most once.

        Returns:
            An async wrapper that calls ``func`` with the given arguments.
        """
        if func not in self._preprocessors_funcs:
            self._preprocessors_funcs.append(func)

        # noinspection PyTypeChecker
        @wraps(func, assigned=WRAPPER_ASSIGNMENTS)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return func(*args, **kwargs)

        return wrapper

    def get_preprocessors_funcs(self) -> List["T_PreprocessorsFunc"]:
        """Return the list of functions registered with ``run_preprocessor``."""
        return self._preprocessors_funcs