mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-24 09:15:45 +00:00
♻️ PaiGram V4
Co-authored-by: luoshuijs <luoshuijs@outlook.com> Co-authored-by: Karako <karakohear@gmail.com> Co-authored-by: xtaodada <xtao@xtaolink.cn>
This commit is contained in:
parent
baceace292
commit
233e7ab58d
32
.env.example
32
.env.example
@ -1,6 +1,12 @@
|
||||
# debug 开关
|
||||
DEBUG=false
|
||||
|
||||
AUTO_RELOAD=false
|
||||
RELOAD_DELAY=0.25
|
||||
RELOAD_DIRS=[]
|
||||
RELOAD_INCLUDE=[]
|
||||
RELOAD_EXCLUDE=[]
|
||||
|
||||
# MySQL
|
||||
DB_HOST=127.0.0.1
|
||||
DB_PORT=3306
|
||||
@ -17,14 +23,14 @@ REDIS_PASSWORD=""
|
||||
# 联系 https://t.me/BotFather 使用 /newbot 命令创建机器人并获取 token
|
||||
BOT_TOKEN="xxxxxxx"
|
||||
|
||||
# bot 管理员
|
||||
ADMINS=[{ "username": "", "user_id": -1 }]
|
||||
# bot 所有者
|
||||
OWNER=0
|
||||
|
||||
# 记录错误并发送消息通知开发人员 可选配置项
|
||||
# ERROR_NOTIFICATION_CHAT_ID=chat_id
|
||||
|
||||
# 文章推送群组 可选配置项
|
||||
# CHANNELS=[{ "name": "", "chat_id": 1}]
|
||||
# CHANNELS=[]
|
||||
|
||||
# 是否允许机器人邀请到其他群 默认不允许 如果允许 可以允许全部人或有认证选项 可选配置项
|
||||
# JOIN_GROUPS = "NO_ALLOW"
|
||||
@ -33,20 +39,20 @@ ADMINS=[{ "username": "", "user_id": -1 }]
|
||||
# VERIFY_GROUPS=[]
|
||||
|
||||
# logger 配置 可选配置项
|
||||
LOGGER_NAME="TGPaimon"
|
||||
# LOGGER_NAME="TGPaimon"
|
||||
# 打印时的宽度
|
||||
LOGGER_WIDTH=180
|
||||
# LOGGER_WIDTH=180
|
||||
# log 文件存放目录
|
||||
LOGGER_LOG_PATH="logs"
|
||||
# LOGGER_LOG_PATH="logs"
|
||||
# log 时间格式,参考 datetime.strftime
|
||||
LOGGER_TIME_FORMAT="[%Y-%m-%d %X]"
|
||||
# LOGGER_TIME_FORMAT="[%Y-%m-%d %X]"
|
||||
# log 高亮关键词
|
||||
LOGGER_RENDER_KEYWORDS=["BOT"]
|
||||
# LOGGER_RENDER_KEYWORDS=["BOT"]
|
||||
# traceback 相关配置
|
||||
LOGGER_TRACEBACK_MAX_FRAMES=20
|
||||
LOGGER_LOCALS_MAX_DEPTH=0
|
||||
LOGGER_LOCALS_MAX_LENGTH=10
|
||||
LOGGER_LOCALS_MAX_STRING=80
|
||||
# LOGGER_TRACEBACK_MAX_FRAMES=20
|
||||
# LOGGER_LOCALS_MAX_DEPTH=0
|
||||
# LOGGER_LOCALS_MAX_LENGTH=10
|
||||
# LOGGER_LOCALS_MAX_STRING=80
|
||||
# 可被 logger 打印的 record 的名称(默认包含了 LOGGER_NAME )
|
||||
LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"]
|
||||
|
||||
@ -77,7 +83,7 @@ LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"]
|
||||
# ENKA_NETWORK_API_AGENT=""
|
||||
|
||||
# Web Server
|
||||
# 目前只用于预览模板,仅开发环境启动
|
||||
# WEB_SWITCH=False # 是否开启
|
||||
# WEB_URL=http://localhost:8080/
|
||||
# WEB_HOST=localhost
|
||||
# WEB_PORT=8080
|
||||
|
54
.github/workflows/integration-test.yml
vendored
Normal file
54
.github/workflows/integration-test.yml
vendored
Normal file
@ -0,0 +1,54 @@
|
||||
name: Integration Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'tests/integration/**'
|
||||
pull_request:
|
||||
types: [ opened, synchronize ]
|
||||
paths:
|
||||
- 'core/services/**'
|
||||
- 'core/dependence/**'
|
||||
- 'tests/integration/**'
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
name: pytest
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
mysql:
|
||||
image: mysql:5.7
|
||||
env:
|
||||
MYSQL_DATABASE: integration_test
|
||||
MYSQL_ROOT_PASSWORD: 123456test
|
||||
ports:
|
||||
- 3306:3306
|
||||
options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
|
||||
redis:
|
||||
image: redis
|
||||
ports:
|
||||
- 6379:6379
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.11
|
||||
- name: Setup integration test environment
|
||||
run: cp tests/integration/.env.example .env && cp tests/integration/.env.example tests/integration/.env
|
||||
- name: Create venv
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
python3 -m venv venv
|
||||
- name: Install requirements
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
python3 -m pip install --upgrade poetry
|
||||
python3 -m poetry install --extras all
|
||||
- name: Run test
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
python3 -m pytest tests/integration
|
21
.github/workflows/test.yml
vendored
21
.github/workflows/test.yml
vendored
@ -1,19 +1,17 @@
|
||||
name: test
|
||||
name: Test modules
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'tests/**'
|
||||
- 'tests/unit/**'
|
||||
pull_request:
|
||||
types: [ opened, synchronize ]
|
||||
paths:
|
||||
- 'modules/apihelper/**'
|
||||
- 'modules/wiki/**'
|
||||
- 'tests/**'
|
||||
schedule:
|
||||
- cron: '0 4 * * 3'
|
||||
- 'tests/unit/**'
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
@ -22,16 +20,15 @@ jobs:
|
||||
continue-on-error: ${{ matrix.experimental }}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [ '3.10' ]
|
||||
os: [ ubuntu-latest, windows-latest ]
|
||||
experimental: [ false ]
|
||||
fail-fast: False
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: 3.11
|
||||
- name: restore or create a python virtualenv
|
||||
id: cache
|
||||
uses: syphar/restore-virtualenv@v1.2
|
||||
@ -45,4 +42,4 @@ jobs:
|
||||
poetry install --extras test
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
python -m pytest
|
||||
python -m pytest tests/unit
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -58,6 +58,5 @@ plugins/private
|
||||
.pytest_cache
|
||||
|
||||
### mtp ###
|
||||
paimon.session
|
||||
PaimonBot.session
|
||||
PaimonBot.session-journal
|
||||
paigram.session
|
||||
paigram.session-journal
|
||||
|
@ -1,7 +1,6 @@
|
||||
<h1 align="center">PaiGram</h1>
|
||||
|
||||
<div align="center">
|
||||
<img src="https://img.shields.io/badge/python-3.8%2B-blue" alt="">
|
||||
<div align="center">·<img src="https://img.shields.io/badge/python-3.11%2B-blue" alt="">
|
||||
<img src="https://img.shields.io/badge/works%20on-my%20machine-brightgreen" alt="">
|
||||
<img src="https://img.shields.io/badge/status-%E5%92%95%E5%92%95%E5%92%95-blue" alt="">
|
||||
<a href="https://black.readthedocs.io/en/stable/index.html"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="code_style" /></a>
|
||||
@ -19,7 +18,7 @@
|
||||
|
||||
## 环境需求
|
||||
|
||||
- Python 3.8+
|
||||
- Python 3.11+
|
||||
- MySQL
|
||||
- Redis
|
||||
|
||||
|
@ -6,19 +6,13 @@ from logging.config import fileConfig
|
||||
from typing import Iterator
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import (
|
||||
engine_from_config,
|
||||
pool,
|
||||
)
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from utils.const import (
|
||||
CORE_DIR,
|
||||
PLUGIN_DIR,
|
||||
PROJECT_ROOT,
|
||||
)
|
||||
from core.config import config as BotConfig
|
||||
from utils.const import CORE_DIR, PLUGIN_DIR, PROJECT_ROOT
|
||||
from utils.log import logger
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
@ -28,7 +22,7 @@ config = context.config
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
fileConfig(config.config_file_name) # skipcq: PY-A6006
|
||||
|
||||
|
||||
def scan_models() -> Iterator[str]:
|
||||
@ -46,7 +40,7 @@ def import_models():
|
||||
try:
|
||||
import_module(pkg) # 导入 models
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.error(f'在导入文件 "{pkg}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]')
|
||||
logger.error("在导入文件 %s 的过程中遇到了错误: \n[red bold]%s: %s[/]", pkg, type(e).__name__, e, extra={"markup": True})
|
||||
|
||||
|
||||
# register our models for alembic to auto-generate migrations
|
||||
@ -61,14 +55,13 @@ target_metadata = SQLModel.metadata
|
||||
|
||||
# here we allow ourselves to pass interpolation vars to alembic.ini
|
||||
# from the application config module
|
||||
from core.config import config as botConfig
|
||||
|
||||
section = config.config_ini_section
|
||||
config.set_section_option(section, "DB_HOST", botConfig.mysql.host)
|
||||
config.set_section_option(section, "DB_PORT", str(botConfig.mysql.port))
|
||||
config.set_section_option(section, "DB_USERNAME", botConfig.mysql.username)
|
||||
config.set_section_option(section, "DB_PASSWORD", botConfig.mysql.password)
|
||||
config.set_section_option(section, "DB_DATABASE", botConfig.mysql.database)
|
||||
config.set_section_option(section, "DB_HOST", BotConfig.mysql.host)
|
||||
config.set_section_option(section, "DB_PORT", str(BotConfig.mysql.port))
|
||||
config.set_section_option(section, "DB_USERNAME", BotConfig.mysql.username)
|
||||
config.set_section_option(section, "DB_PASSWORD", BotConfig.mysql.password)
|
||||
config.set_section_option(section, "DB_DATABASE", BotConfig.mysql.database)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
|
@ -5,16 +5,19 @@ Revises:
|
||||
Create Date: 2022-09-01 16:55:20.372560
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from base64 import b64decode
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9e9a36470cd5"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode()
|
||||
old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
@ -22,7 +25,7 @@ def upgrade() -> None:
|
||||
op.create_table(
|
||||
"question",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("text", sqlmodel.AutoString(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
mysql_charset="utf8mb4",
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
@ -35,7 +38,7 @@ def upgrade() -> None:
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("yuanshen_uid", sa.Integer(), nullable=True),
|
||||
sa.Column("genshin_uid", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
@ -46,7 +49,7 @@ def upgrade() -> None:
|
||||
op.create_table(
|
||||
"admin",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.user_id"],
|
||||
@ -60,7 +63,7 @@ def upgrade() -> None:
|
||||
sa.Column("question_id", sa.Integer(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("is_correct", sa.Boolean(), nullable=True),
|
||||
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("text", sqlmodel.AutoString(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["question_id"],
|
||||
["question.id"],
|
||||
@ -72,7 +75,7 @@ def upgrade() -> None:
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
)
|
||||
op.create_table(
|
||||
"hoyoverse_cookies",
|
||||
old_cookies_database_name2,
|
||||
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
@ -85,7 +88,7 @@ def upgrade() -> None:
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.user_id"],
|
||||
@ -95,7 +98,7 @@ def upgrade() -> None:
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
)
|
||||
op.create_table(
|
||||
"mihoyo_cookies",
|
||||
old_cookies_database_name1,
|
||||
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
@ -108,7 +111,7 @@ def upgrade() -> None:
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.user_id"],
|
||||
@ -119,6 +122,9 @@ def upgrade() -> None:
|
||||
)
|
||||
op.create_table(
|
||||
"sign",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("chat_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
@ -140,14 +146,11 @@ def upgrade() -> None:
|
||||
),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_id", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.PrimaryKeyConstraint("id", "user_id"),
|
||||
mysql_charset="utf8mb4",
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
)
|
||||
@ -157,8 +160,8 @@ def upgrade() -> None:
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("sign")
|
||||
op.drop_table("mihoyo_cookies")
|
||||
op.drop_table("hoyoverse_cookies")
|
||||
op.drop_table(old_cookies_database_name1)
|
||||
op.drop_table(old_cookies_database_name2)
|
||||
op.drop_table("answer")
|
||||
op.drop_table("admin")
|
||||
op.drop_table("user")
|
||||
|
301
alembic/versions/ddcfba3c7d5c_v4.py
Normal file
301
alembic/versions/ddcfba3c7d5c_v4.py
Normal file
@ -0,0 +1,301 @@
|
||||
"""v4
|
||||
|
||||
Revision ID: ddcfba3c7d5c
|
||||
Revises: 9e9a36470cd5
|
||||
Create Date: 2023-02-11 17:07:18.170175
|
||||
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from base64 import b64decode
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ddcfba3c7d5c"
|
||||
down_revision = "9e9a36470cd5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode()
|
||||
old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
connection = op.get_bind()
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
cookies_table = op.create_table(
|
||||
"cookies",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("account_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"STATUS_SUCCESS",
|
||||
"INVALID_COOKIES",
|
||||
"TOO_MANY_REQUESTS",
|
||||
name="cookiesstatusenum",
|
||||
),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"region",
|
||||
sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("is_share", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.Index("index_user_account", "user_id", "account_id", unique=True),
|
||||
mysql_charset="utf8mb4",
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
)
|
||||
for old_cookies_database_name in (old_cookies_database_name1, old_cookies_database_name2):
|
||||
try:
|
||||
statement = f"SELECT * FROM {old_cookies_database_name};" # skipcq: BAN-B608
|
||||
old_cookies_table_data = connection.execute(text(statement))
|
||||
except NoSuchTableError:
|
||||
logger.warning("Table '%s' doesn't exist", old_cookies_database_name)
|
||||
continue
|
||||
if old_cookies_table_data is None:
|
||||
logger.warning("Old Cookies Database is None")
|
||||
continue
|
||||
for row in old_cookies_table_data:
|
||||
try:
|
||||
user_id = row["user_id"]
|
||||
status = row["status"]
|
||||
cookies_row = row["cookies"]
|
||||
cookies_data = json.loads(cookies_row)
|
||||
account_id = cookies_data.get("account_id")
|
||||
if account_id is None: # Cleaning Data 清洗数据
|
||||
account_id = cookies_data.get("ltuid")
|
||||
else:
|
||||
account_mid_v2 = cookies_data.get("account_mid_v2")
|
||||
if account_mid_v2 is not None:
|
||||
cookies_data.pop("account_id")
|
||||
cookies_data.setdefault("account_uid_v2", account_id)
|
||||
if old_cookies_database_name == old_cookies_database_name1:
|
||||
region = "HYPERION"
|
||||
else:
|
||||
region = "HOYOLAB"
|
||||
if account_id is None:
|
||||
logger.warning("Can not get user account_id, user_id :%s", user_id)
|
||||
continue
|
||||
insert = cookies_table.insert().values(
|
||||
user_id=int(user_id),
|
||||
account_id=int(account_id),
|
||||
status=status,
|
||||
data=cookies_data,
|
||||
region=region,
|
||||
is_share=True,
|
||||
)
|
||||
with op.get_context().autocommit_block():
|
||||
connection.execute(insert)
|
||||
except Exception as exc: # pylint: disable=W0703
|
||||
logger.error(
|
||||
"Process %s->cookies Exception", old_cookies_database_name, exc_info=exc
|
||||
) # pylint: disable=W0703
|
||||
players_table = op.create_table(
|
||||
"players",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("account_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("player_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column(
|
||||
"region",
|
||||
sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("is_chosen", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
|
||||
mysql_charset="utf8mb4",
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
)
|
||||
|
||||
try:
|
||||
statement = "SELECT * FROM user;"
|
||||
old_user_table_data = connection.execute(text(statement))
|
||||
except NoSuchTableError:
|
||||
logger.warning("Table 'user' doesn't exist")
|
||||
return # should not happen
|
||||
if old_user_table_data is not None:
|
||||
for row in old_user_table_data:
|
||||
try:
|
||||
user_id = row["user_id"]
|
||||
y_uid = row["yuanshen_uid"]
|
||||
g_uid = row["genshin_uid"]
|
||||
region = row["region"]
|
||||
account_id = None
|
||||
cookies_row = connection.execute(
|
||||
cookies_table.select().where(cookies_table.c.user_id == user_id)
|
||||
).first()
|
||||
if cookies_row is not None:
|
||||
account_id = cookies_row["account_id"]
|
||||
if y_uid:
|
||||
insert = players_table.insert().values(
|
||||
user_id=int(user_id),
|
||||
player_id=int(y_uid),
|
||||
is_chosen=(region == "HYPERION"),
|
||||
region="HYPERION",
|
||||
account_id=account_id,
|
||||
)
|
||||
with op.get_context().autocommit_block():
|
||||
connection.execute(insert)
|
||||
if g_uid:
|
||||
insert = players_table.insert().values(
|
||||
user_id=int(user_id),
|
||||
player_id=int(g_uid),
|
||||
is_chosen=(region == "HOYOLAB"),
|
||||
region="HOYOLAB",
|
||||
account_id=account_id,
|
||||
)
|
||||
with op.get_context().autocommit_block():
|
||||
connection.execute(insert)
|
||||
except Exception as exc: # pylint: disable=W0703
|
||||
logger.error("Process user->player Exception", exc_info=exc)
|
||||
else:
|
||||
logger.warning("Old User Database is None")
|
||||
|
||||
users_table = op.create_table(
|
||||
"users",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False, primary_key=True),
|
||||
sa.Column(
|
||||
"permissions",
|
||||
sa.Enum("OWNER", "ADMIN", "PUBLIC", name="permissionsenum"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("locale", sqlmodel.AutoString(), nullable=True),
|
||||
sa.Column("is_banned", sa.BigInteger(), nullable=True),
|
||||
sa.Column("ban_end_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("ban_start_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
mysql_charset="utf8mb4",
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
)
|
||||
|
||||
try:
|
||||
statement = "SELECT * FROM admin;"
|
||||
old_user_table_data = connection.execute(text(statement))
|
||||
except NoSuchTableError:
|
||||
logger.warning("Table 'admin' doesn't exist")
|
||||
return # should not happen
|
||||
if old_user_table_data is not None:
|
||||
for row in old_user_table_data:
|
||||
try:
|
||||
user_id = row["user_id"]
|
||||
insert = users_table.insert().values(
|
||||
user_id=int(user_id),
|
||||
permissions="ADMIN",
|
||||
)
|
||||
with op.get_context().autocommit_block():
|
||||
connection.execute(insert)
|
||||
except Exception as exc: # pylint: disable=W0703
|
||||
logger.error("Process admin->users Exception", exc_info=exc)
|
||||
else:
|
||||
logger.warning("Old User Database is None")
|
||||
|
||||
op.create_table(
|
||||
"players_info",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("player_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("nickname", sqlmodel.AutoString(length=128), nullable=True),
|
||||
sa.Column("signature", sqlmodel.AutoString(length=255), nullable=True),
|
||||
sa.Column("hand_image", sa.Integer(), nullable=True),
|
||||
sa.Column("name_card", sa.Integer(), nullable=True),
|
||||
sa.Column("extra_data", sa.VARCHAR(length=512), nullable=True),
|
||||
sa.Column(
|
||||
"create_time",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("last_save_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("is_update", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.Index("index_user_player", "user_id", "player_id", unique=True),
|
||||
mysql_charset="utf8mb4",
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
)
|
||||
|
||||
op.drop_table(old_cookies_database_name1)
|
||||
op.drop_table(old_cookies_database_name2)
|
||||
op.drop_table("admin")
|
||||
op.drop_constraint("sign_ibfk_1", "sign", type_="foreignkey")
|
||||
op.drop_index("user_id", table_name="sign")
|
||||
op.drop_table("user")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("region", sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"), nullable=True),
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False),
|
||||
sa.Column("yuanshen_uid", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
sa.Column("genshin_uid", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
mysql_default_charset="utf8mb4",
|
||||
mysql_engine="InnoDB",
|
||||
)
|
||||
op.create_index("user_id", "user", ["user_id"], unique=False)
|
||||
op.create_table(
|
||||
"admin",
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="admin_ibfk_1"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
mysql_default_charset="utf8mb4",
|
||||
mysql_engine="InnoDB",
|
||||
)
|
||||
op.create_table(
|
||||
old_cookies_database_name1,
|
||||
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="mihoyo_cookies_ibfk_1"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
mysql_default_charset="utf8mb4",
|
||||
mysql_engine="InnoDB",
|
||||
)
|
||||
op.create_table(
|
||||
old_cookies_database_name2,
|
||||
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="hoyoverse_cookies_ibfk_1"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
mysql_collate="utf8mb4_general_ci",
|
||||
mysql_default_charset="utf8mb4",
|
||||
mysql_engine="InnoDB",
|
||||
)
|
||||
op.create_foreign_key("sign_ibfk_1", "sign", "user", ["user_id"], ["user_id"])
|
||||
op.create_index("user_id", "sign", ["user_id"], unique=False)
|
||||
op.drop_table("users")
|
||||
op.drop_table("players")
|
||||
op.drop_table("cookies")
|
||||
op.drop_table("players_info")
|
||||
# ### end Alembic commands ###
|
@ -1,14 +0,0 @@
|
||||
from core.service import init_service
|
||||
from core.base.mysql import MySQL
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.admin.cache import BotAdminCache
|
||||
from core.admin.repositories import BotAdminRepository
|
||||
from core.admin.services import BotAdminService
|
||||
|
||||
|
||||
@init_service
|
||||
def create_bot_admin_service(mysql: MySQL, redis: RedisDB):
|
||||
_cache = BotAdminCache(redis)
|
||||
_repository = BotAdminRepository(mysql)
|
||||
_service = BotAdminService(_repository, _cache)
|
||||
return _service
|
@ -1,38 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from core.base.redisdb import RedisDB
|
||||
|
||||
|
||||
class BotAdminCache:
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "bot:admin"
|
||||
|
||||
async def get_list(self):
|
||||
return [int(str_data) for str_data in await self.client.lrange(self.qname, 0, -1)]
|
||||
|
||||
async def set_list(self, str_list: List[int], ttl: int = -1):
|
||||
await self.client.ltrim(self.qname, 1, 0)
|
||||
await self.client.lpush(self.qname, *str_list)
|
||||
if ttl != -1:
|
||||
await self.client.expire(self.qname, ttl)
|
||||
count = await self.client.llen(self.qname)
|
||||
return count
|
||||
|
||||
|
||||
class GroupAdminCache:
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "group:admin_list"
|
||||
|
||||
async def get_chat_admin(self, chat_id: int):
|
||||
qname = f"{self.qname}:{chat_id}"
|
||||
return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)]
|
||||
|
||||
async def set_chat_admin(self, chat_id: int, admin_list: List[int]):
|
||||
qname = f"{self.qname}:{chat_id}"
|
||||
await self.client.ltrim(qname, 1, 0)
|
||||
await self.client.lpush(qname, *admin_list)
|
||||
await self.client.expire(qname, 60)
|
||||
count = await self.client.llen(qname)
|
||||
return count
|
@ -1,8 +0,0 @@
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class Admin(SQLModel, table=True):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
id: int = Field(primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.user_id")
|
@ -1,33 +0,0 @@
|
||||
from typing import List, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from core.admin.models import Admin
|
||||
from core.base.mysql import MySQL
|
||||
|
||||
|
||||
class BotAdminRepository:
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.mysql = mysql
|
||||
|
||||
async def delete_by_user_id(self, user_id: int):
|
||||
async with self.mysql.Session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
statement = select(Admin).where(Admin.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
admin = results.one()
|
||||
await session.delete(admin)
|
||||
|
||||
async def add_by_user_id(self, user_id: int):
|
||||
async with self.mysql.Session() as session:
|
||||
admin = Admin(user_id=user_id)
|
||||
session.add(admin)
|
||||
await session.commit()
|
||||
|
||||
async def get_all_user_id(self) -> List[int]:
|
||||
async with self.mysql.Session() as session:
|
||||
query = select(Admin)
|
||||
results = await session.exec(query)
|
||||
admins = results.all()
|
||||
return [admin[0].user_id for admin in admins]
|
@ -1,60 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from asyncmy.errors import IntegrityError
|
||||
from telegram import Bot
|
||||
|
||||
from core.admin.cache import BotAdminCache, GroupAdminCache
|
||||
from core.admin.repositories import BotAdminRepository
|
||||
from core.config import config
|
||||
from utils.log import logger
|
||||
|
||||
|
||||
class BotAdminService:
|
||||
def __init__(self, repository: BotAdminRepository, cache: BotAdminCache):
|
||||
self._repository = repository
|
||||
self._cache = cache
|
||||
|
||||
async def get_admin_list(self) -> List[int]:
|
||||
admin_list = await self._cache.get_list()
|
||||
if len(admin_list) == 0:
|
||||
admin_list = await self._repository.get_all_user_id()
|
||||
for config_admin in config.admins:
|
||||
admin_list.append(config_admin.user_id)
|
||||
await self._cache.set_list(admin_list)
|
||||
return admin_list
|
||||
|
||||
async def add_admin(self, user_id: int) -> bool:
|
||||
try:
|
||||
await self._repository.add_by_user_id(user_id)
|
||||
except IntegrityError:
|
||||
logger.warning("用户 %s 已经存在 Admin 数据库", user_id)
|
||||
admin_list = await self._repository.get_all_user_id()
|
||||
for config_admin in config.admins:
|
||||
admin_list.append(config_admin.user_id)
|
||||
await self._cache.set_list(admin_list)
|
||||
return True
|
||||
|
||||
async def delete_admin(self, user_id: int) -> bool:
|
||||
try:
|
||||
await self._repository.delete_by_user_id(user_id)
|
||||
except ValueError:
|
||||
return False
|
||||
admin_list = await self._repository.get_all_user_id()
|
||||
for config_admin in config.admins:
|
||||
admin_list.append(config_admin.user_id)
|
||||
await self._cache.set_list(admin_list)
|
||||
return True
|
||||
|
||||
|
||||
class GroupAdminService:
|
||||
def __init__(self, cache: GroupAdminCache):
|
||||
self._cache = cache
|
||||
|
||||
async def get_admins(self, bot: Bot, chat_id: int, extra_user: List[int]) -> List[int]:
|
||||
admin_id_list = await self._cache.get_chat_admin(chat_id)
|
||||
if len(admin_id_list) == 0:
|
||||
admin_list = await bot.get_chat_administrators(chat_id)
|
||||
admin_id_list = [admin.user.id for admin in admin_list]
|
||||
await self._cache.set_chat_admin(chat_id, admin_id_list)
|
||||
admin_id_list += extra_user
|
||||
return admin_id_list
|
287
core/application.py
Normal file
287
core/application.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""BOT"""
|
||||
import asyncio
|
||||
import signal
|
||||
from functools import wraps
|
||||
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
|
||||
from ssl import SSLZeroReturnError
|
||||
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar
|
||||
|
||||
import pytz
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from telegram import Bot, Update
|
||||
from telegram.error import NetworkError, TelegramError, TimedOut
|
||||
from telegram.ext import (
|
||||
Application as TelegramApplication,
|
||||
ApplicationBuilder as TelegramApplicationBuilder,
|
||||
Defaults,
|
||||
JobQueue,
|
||||
)
|
||||
from typing_extensions import ParamSpec
|
||||
from uvicorn import Server
|
||||
|
||||
from core.config import config as application_config
|
||||
from core.handler.limiterhandler import LimiterHandler
|
||||
from core.manager import Managers
|
||||
from core.override.telegram import HTTPXRequest
|
||||
from utils.const import WRAPPER_ASSIGNMENTS
|
||||
from utils.log import logger
|
||||
from utils.models.signal import Singleton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from asyncio import Task
|
||||
from types import FrameType
|
||||
|
||||
__all__ = ("Application",)
|
||||
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class Application(Singleton):
|
||||
"""Application"""
|
||||
|
||||
_web_server_task: Optional["Task"] = None
|
||||
|
||||
_startup_funcs: List[Callable] = []
|
||||
_shutdown_funcs: List[Callable] = []
|
||||
|
||||
def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
|
||||
self._running = False
|
||||
self.managers = managers
|
||||
self.telegram = telegram
|
||||
self.web_server = web_server
|
||||
self.managers.set_application(application=self) # 给 managers 设置 application
|
||||
self.managers.build_executor("Application")
|
||||
|
||||
@classmethod
|
||||
def build(cls):
|
||||
managers = Managers()
|
||||
telegram = (
|
||||
TelegramApplicationBuilder()
|
||||
.get_updates_read_timeout(application_config.update_read_timeout)
|
||||
.get_updates_write_timeout(application_config.update_write_timeout)
|
||||
.get_updates_connect_timeout(application_config.update_connect_timeout)
|
||||
.get_updates_pool_timeout(application_config.update_pool_timeout)
|
||||
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
|
||||
.token(application_config.bot_token)
|
||||
.request(
|
||||
HTTPXRequest(
|
||||
connection_pool_size=application_config.connection_pool_size,
|
||||
proxy_url=application_config.proxy_url,
|
||||
read_timeout=application_config.read_timeout,
|
||||
write_timeout=application_config.write_timeout,
|
||||
connect_timeout=application_config.connect_timeout,
|
||||
pool_timeout=application_config.pool_timeout,
|
||||
)
|
||||
)
|
||||
.build()
|
||||
)
|
||||
web_server = Server(
|
||||
uvicorn.Config(
|
||||
app=FastAPI(debug=application_config.debug),
|
||||
port=application_config.webserver.port,
|
||||
host=application_config.webserver.host,
|
||||
log_config=None,
|
||||
)
|
||||
)
|
||||
return cls(managers, telegram, web_server)
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""bot 是否正在运行"""
|
||||
with self._lock:
|
||||
return self._running
|
||||
|
||||
@property
|
||||
def web_app(self) -> FastAPI:
|
||||
"""fastapi app"""
|
||||
return self.web_server.config.app
|
||||
|
||||
@property
|
||||
def bot(self) -> Optional[Bot]:
|
||||
return self.telegram.bot
|
||||
|
||||
@property
|
||||
def job_queue(self) -> Optional[JobQueue]:
|
||||
return self.telegram.job_queue
|
||||
|
||||
async def _on_startup(self) -> None:
|
||||
for func in self._startup_funcs:
|
||||
await self.managers.executor(func, block=getattr(func, "block", False))
|
||||
|
||||
async def _on_shutdown(self) -> None:
|
||||
for func in self._shutdown_funcs:
|
||||
await self.managers.executor(func, block=getattr(func, "block", False))
|
||||
|
||||
async def initialize(self):
|
||||
"""BOT 初始化"""
|
||||
self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制
|
||||
await self.managers.start_dependency() # 启动基础服务
|
||||
await self.managers.init_components() # 实例化组件
|
||||
await self.managers.start_services() # 启动其他服务
|
||||
await self.managers.install_plugins() # 安装插件
|
||||
|
||||
async def shutdown(self):
|
||||
"""BOT 关闭"""
|
||||
await self.managers.uninstall_plugins() # 卸载插件
|
||||
await self.managers.stop_services() # 终止其他服务
|
||||
await self.managers.stop_dependency() # 终止基础服务
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 BOT"""
|
||||
logger.info("正在启动 BOT 中...")
|
||||
|
||||
def error_callback(exc: TelegramError) -> None:
|
||||
"""错误信息回调"""
|
||||
self.telegram.create_task(self.telegram.process_error(error=exc, update=None))
|
||||
|
||||
await self.telegram.initialize()
|
||||
logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})
|
||||
|
||||
if application_config.webserver.enable: # 如果使用 web app
|
||||
server_config = self.web_server.config
|
||||
server_config.setup_event_loop()
|
||||
if not server_config.loaded:
|
||||
server_config.load()
|
||||
self.web_server.lifespan = server_config.lifespan_class(server_config)
|
||||
try:
|
||||
await self.web_server.startup()
|
||||
except OSError as e:
|
||||
if e.errno == 10048:
|
||||
logger.error("Web Server 端口被占用:%s", e)
|
||||
logger.error("Web Server 启动失败,正在退出")
|
||||
raise SystemExit from None
|
||||
|
||||
if self.web_server.should_exit:
|
||||
logger.error("Web Server 启动失败,正在退出")
|
||||
raise SystemExit from None
|
||||
logger.success("Web Server 启动成功")
|
||||
|
||||
self._web_server_task = asyncio.create_task(self.web_server.main_loop())
|
||||
|
||||
for _ in range(5): # 连接至 telegram 服务器
|
||||
try:
|
||||
await self.telegram.updater.start_polling(
|
||||
error_callback=error_callback, allowed_updates=Update.ALL_TYPES
|
||||
)
|
||||
break
|
||||
except TimedOut:
|
||||
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
|
||||
continue
|
||||
except NetworkError as e:
|
||||
logger.exception()
|
||||
if isinstance(e, SSLZeroReturnError):
|
||||
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||||
else:
|
||||
logger.error("网络连接出现问题, 请检查您的网络状况.")
|
||||
raise SystemExit from e
|
||||
|
||||
await self.initialize()
|
||||
logger.success("BOT 初始化成功")
|
||||
logger.debug("BOT 开始启动")
|
||||
|
||||
await self._on_startup()
|
||||
await self.telegram.start()
|
||||
self._running = True
|
||||
logger.success("BOT 启动成功")
|
||||
|
||||
def stop_signal_handler(self, signum: int):
|
||||
"""终止信号处理"""
|
||||
signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")}
|
||||
logger.debug("接收到了终止信号 %s 正在退出...", signals[signum])
|
||||
if self._web_server_task:
|
||||
self._web_server_task.cancel()
|
||||
|
||||
async def idle(self) -> None:
|
||||
"""在接收到中止信号之前,堵塞loop"""
|
||||
|
||||
task = None
|
||||
|
||||
def stop_handler(signum: int, _: "FrameType") -> None:
|
||||
self.stop_signal_handler(signum)
|
||||
task.cancel()
|
||||
|
||||
for s in (SIGINT, SIGTERM, SIGABRT):
|
||||
signal_func(s, stop_handler)
|
||||
|
||||
while True:
|
||||
task = asyncio.create_task(asyncio.sleep(600))
|
||||
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""关闭"""
|
||||
logger.info("BOT 正在关闭")
|
||||
self._running = False
|
||||
|
||||
await self._on_shutdown()
|
||||
|
||||
if self.telegram.updater.running:
|
||||
await self.telegram.updater.stop()
|
||||
|
||||
await self.shutdown()
|
||||
|
||||
if self.telegram.running:
|
||||
await self.telegram.stop()
|
||||
|
||||
await self.telegram.shutdown()
|
||||
if self.web_server is not None:
|
||||
try:
|
||||
await self.web_server.shutdown()
|
||||
logger.info("Web Server 已经关闭")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
logger.success("BOT 关闭成功")
|
||||
|
||||
def launch(self) -> None:
|
||||
"""启动"""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(self.start())
|
||||
loop.run_until_complete(self.idle())
|
||||
except (SystemExit, KeyboardInterrupt) as exc:
|
||||
logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号
|
||||
except NetworkError as e:
|
||||
if isinstance(e, SSLZeroReturnError):
|
||||
logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||||
else:
|
||||
logger.critical("网络连接出现问题, 请检查您的网络状况.")
|
||||
except Exception as e:
|
||||
logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e)
|
||||
finally:
|
||||
loop.run_until_complete(self.stop())
|
||||
|
||||
if application_config.reload:
|
||||
raise SystemExit from None
|
||||
|
||||
def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""注册一个在 BOT 启动时执行的函数"""
|
||||
|
||||
if func not in self._startup_funcs:
|
||||
self._startup_funcs.append(func)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""注册一个在 BOT 停止时执行的函数"""
|
||||
|
||||
if func not in self._shutdown_funcs:
|
||||
self._shutdown_funcs.append(func)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
@ -1,31 +0,0 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from typing_extensions import Self
|
||||
|
||||
from core.config import BotConfig
|
||||
from core.service import Service
|
||||
|
||||
|
||||
class MySQL(Service):
|
||||
@classmethod
|
||||
def from_config(cls, config: BotConfig) -> Self:
|
||||
return cls(**config.mysql.dict())
|
||||
|
||||
def __init__(self, host: str, port: int, username: str, password: str, database: str):
|
||||
self.database = database
|
||||
self.password = password
|
||||
self.user = username
|
||||
self.port = port
|
||||
self.host = host
|
||||
self.url = f"mysql+asyncmy://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||
self.engine = create_async_engine(self.url)
|
||||
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
|
||||
|
||||
async def get_session(self):
|
||||
"""获取会话"""
|
||||
async with self.Session() as session:
|
||||
yield session
|
||||
|
||||
async def stop(self):
|
||||
self.Session.close_all()
|
@ -1,64 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from core.config import (
|
||||
BotConfig,
|
||||
config as botConfig,
|
||||
)
|
||||
from core.service import Service
|
||||
|
||||
__all__ = ["webapp", "WebServer"]
|
||||
|
||||
webapp = FastAPI(debug=botConfig.debug)
|
||||
|
||||
|
||||
@webapp.get("/")
|
||||
def index():
|
||||
return {"Hello": "Paimon"}
|
||||
|
||||
|
||||
class WebServer(Service):
|
||||
debug: bool
|
||||
|
||||
host: str
|
||||
port: int
|
||||
|
||||
server: uvicorn.Server
|
||||
|
||||
_server_task: asyncio.Task
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: BotConfig) -> Service:
|
||||
return cls(debug=config.debug, host=config.webserver.host, port=config.webserver.port)
|
||||
|
||||
def __init__(self, debug: bool, host: str, port: int):
|
||||
self.debug = debug
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.server = uvicorn.Server(
|
||||
uvicorn.Config(app=webapp, port=port, use_colors=False, host=host, log_config=None)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
"""启动 service"""
|
||||
|
||||
# 暂时只在开发环境启动 webserver 用于开发调试
|
||||
if not self.debug:
|
||||
return
|
||||
|
||||
# 防止 uvicorn server 拦截 signals
|
||||
self.server.install_signal_handlers = lambda: None
|
||||
self._server_task = asyncio.create_task(self.server.serve())
|
||||
|
||||
async def stop(self):
|
||||
"""关闭 service"""
|
||||
if not self.debug:
|
||||
return
|
||||
|
||||
self.server.should_exit = True
|
||||
|
||||
# 等待 task 结束
|
||||
await self._server_task
|
60
core/base_service.py
Normal file
60
core/base_service.py
Normal file
@ -0,0 +1,60 @@
|
||||
from abc import ABC
|
||||
from itertools import chain
|
||||
from typing import ClassVar, Iterable, Type, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from utils.helpers import isabstract
|
||||
|
||||
__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services")
|
||||
|
||||
|
||||
class _BaseService:
|
||||
"""服务基类"""
|
||||
|
||||
_is_component: ClassVar[bool] = False
|
||||
_is_dependence: ClassVar[bool] = False
|
||||
|
||||
def __init_subclass__(cls, load: bool = True, **kwargs):
|
||||
cls.is_dependence = cls._is_dependence
|
||||
cls.is_component = cls._is_component
|
||||
cls.load = load
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
await self.shutdown()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize resources used by this service"""
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Stop & clear resources used by this service"""
|
||||
|
||||
|
||||
class _Dependence(_BaseService, ABC):
|
||||
_is_dependence: ClassVar[bool] = True
|
||||
|
||||
|
||||
class _Component(_BaseService, ABC):
|
||||
_is_component: ClassVar[bool] = True
|
||||
|
||||
|
||||
class BaseService(_BaseService, ABC):
|
||||
Dependence: Type[_BaseService] = _Dependence
|
||||
Component: Type[_BaseService] = _Component
|
||||
|
||||
|
||||
BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService)
|
||||
DependenceType = TypeVar("DependenceType", bound=_Dependence)
|
||||
ComponentType = TypeVar("ComponentType", bound=_Component)
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def get_all_services() -> Iterable[Type[_BaseService]]:
|
||||
return filter(
|
||||
lambda x: x.__name__[0] != "_" and x.load and not isabstract(x),
|
||||
chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()),
|
||||
)
|
29
core/basemodel.py
Normal file
29
core/basemodel.py
Normal file
@ -0,0 +1,29 @@
|
||||
import enum
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
|
||||
from pydantic import BaseSettings
|
||||
|
||||
__all__ = ("RegionEnum", "Settings")
|
||||
|
||||
|
||||
class RegionEnum(int, enum.Enum):
|
||||
"""账号数据所在服务器"""
|
||||
|
||||
NULL = 0
|
||||
HYPERION = 1 # 米忽悠国服 hyperion
|
||||
HOYOLAB = 2 # 米忽悠国际服 hoyolab
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
cls.update_forward_refs()
|
||||
return super(Settings, cls).__new__(cls) # pylint: disable=E1120
|
||||
|
||||
class Config(BaseSettings.Config):
|
||||
case_sensitive = False
|
||||
json_loads = jsonlib.loads
|
||||
json_dumps = jsonlib.dumps
|
@ -1,69 +0,0 @@
|
||||
from telegram import Update, ReplyKeyboardRemove
|
||||
from telegram.error import BadRequest, Forbidden
|
||||
from telegram.ext import CallbackContext, ConversationHandler
|
||||
|
||||
from core.plugin import handler, conversation
|
||||
from utils.bot import get_chat
|
||||
from utils.log import logger
|
||||
|
||||
|
||||
async def clean_message(context: CallbackContext):
|
||||
job = context.job
|
||||
message_id = job.data
|
||||
chat_info = f"chat_id[{job.chat_id}]"
|
||||
try:
|
||||
chat = await get_chat(job.chat_id)
|
||||
full_name = chat.full_name
|
||||
if full_name:
|
||||
chat_info = f"{full_name}[{chat.id}]"
|
||||
else:
|
||||
chat_info = f"{chat.title}[{chat.id}]"
|
||||
except (BadRequest, Forbidden) as exc:
|
||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
||||
except Exception as exc:
|
||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
||||
try:
|
||||
# noinspection PyTypeChecker
|
||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
||||
except BadRequest as exc:
|
||||
if "not found" in exc.message:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 消息不存在", chat_info, message_id)
|
||||
elif "Message can't be deleted" in exc.message:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 消息无法删除 可能是没有授权", chat_info, message_id)
|
||||
else:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
except Forbidden as exc:
|
||||
if "bot was kicked" in exc.message:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 已经被踢出群", chat_info, message_id)
|
||||
else:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
|
||||
|
||||
def add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int):
|
||||
context.job_queue.run_once(
|
||||
callback=clean_message,
|
||||
when=delete_seconds,
|
||||
data=message_id,
|
||||
name=f"{chat_id}|{message_id}|clean_message",
|
||||
chat_id=chat_id,
|
||||
job_kwargs={"replace_existing": True, "id": f"{chat_id}|{message_id}|clean_message"},
|
||||
)
|
||||
|
||||
|
||||
class _BasePlugin:
|
||||
@staticmethod
|
||||
def _add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int = 60):
|
||||
return add_delete_message_job(context, chat_id, message_id, delete_seconds)
|
||||
|
||||
|
||||
class _Conversation(_BasePlugin):
|
||||
@conversation.fallback
|
||||
@handler.command(command="cancel", block=True)
|
||||
async def cancel(self, update: Update, _: CallbackContext) -> int:
|
||||
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
|
||||
return ConversationHandler.END
|
||||
|
||||
|
||||
class BasePlugin(_BasePlugin):
|
||||
Conversation = _Conversation
|
345
core/bot.py
345
core/bot.py
@ -1,345 +0,0 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
from asyncio import CancelledError
|
||||
from importlib import import_module
|
||||
from multiprocessing import RLock as Lock
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterator, List, NoReturn, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import genshin
|
||||
import pytz
|
||||
from async_timeout import timeout
|
||||
from telegram import Update
|
||||
from telegram import __version__ as tg_version
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
from telegram.ext import (
|
||||
AIORateLimiter,
|
||||
Application as TgApplication,
|
||||
CallbackContext,
|
||||
Defaults,
|
||||
JobQueue,
|
||||
MessageHandler,
|
||||
filters,
|
||||
TypeHandler,
|
||||
)
|
||||
from telegram.ext.filters import StatusUpdate
|
||||
|
||||
from core.config import BotConfig, config # pylint: disable=W0611
|
||||
from core.error import ServiceNotFoundError
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from core.plugin import Plugin, _Plugin
|
||||
from core.service import Service
|
||||
from metadata.scripts.metadatas import make_github_fast
|
||||
from utils.const import PLUGIN_DIR, PROJECT_ROOT
|
||||
from utils.log import logger
|
||||
|
||||
|
||||
__all__ = ["bot"]
|
||||
|
||||
T = TypeVar("T")
|
||||
PluginType = TypeVar("PluginType", bound=_Plugin)
|
||||
|
||||
try:
|
||||
from telegram import __version_info__ as tg_version_info
|
||||
except ImportError:
|
||||
tg_version_info = (0, 0, 0, 0, 0) # type: ignore[assignment]
|
||||
|
||||
if tg_version_info < (20, 0, 0, "alpha", 6):
|
||||
logger.warning(
|
||||
"Bot与当前PTB版本 [cyan bold]%s[/] [red bold]不兼容[/],请更新到最新版本后使用 [blue bold]poetry install[/] 重新安装依赖",
|
||||
tg_version,
|
||||
extra={"markup": True},
|
||||
)
|
||||
|
||||
|
||||
class Bot:
|
||||
_lock: ClassVar[Lock] = Lock()
|
||||
_instance: ClassVar[Optional["Bot"]] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> "Bot":
|
||||
"""实现单例"""
|
||||
with cls._lock: # 使线程、进程安全
|
||||
if cls._instance is None:
|
||||
cls._instance = object.__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
app: Optional[TgApplication] = None
|
||||
_config: BotConfig = config
|
||||
_services: Dict[Type[T], T] = {}
|
||||
_running: bool = False
|
||||
|
||||
def _inject(self, signature: inspect.Signature, target: Callable[..., T]) -> T:
|
||||
kwargs = {}
|
||||
for name, parameter in signature.parameters.items():
|
||||
if name != "self" and parameter.annotation != inspect.Parameter.empty:
|
||||
if value := self._services.get(parameter.annotation):
|
||||
kwargs[name] = value
|
||||
return target(**kwargs)
|
||||
|
||||
def init_inject(self, target: Callable[..., T]) -> T:
|
||||
"""用于实例化Plugin的方法。用于给插件传入一些必要组件,如 MySQL、Redis等"""
|
||||
if isinstance(target, type):
|
||||
signature = inspect.signature(target.__init__)
|
||||
else:
|
||||
signature = inspect.signature(target)
|
||||
return self._inject(signature, target)
|
||||
|
||||
async def async_inject(self, target: Callable[..., T]) -> T:
|
||||
return await self._inject(inspect.signature(target), target)
|
||||
|
||||
def _gen_pkg(self, root: Path) -> Iterator[str]:
|
||||
"""生成可以用于 import_module 导入的字符串"""
|
||||
for path in root.iterdir():
|
||||
if not path.name.startswith("_"):
|
||||
if path.is_dir():
|
||||
yield from self._gen_pkg(path)
|
||||
elif path.suffix == ".py":
|
||||
yield str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".")
|
||||
|
||||
async def install_plugins(self):
|
||||
"""安装插件"""
|
||||
for pkg in self._gen_pkg(PLUGIN_DIR):
|
||||
try:
|
||||
import_module(pkg) # 导入插件
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception(
|
||||
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
||||
)
|
||||
continue # 如有错误则继续
|
||||
callback_dict: Dict[int, List[Callable]] = {}
|
||||
for plugin_cls in {*Plugin.__subclasses__(), *Plugin.Conversation.__subclasses__()}:
|
||||
path = f"{plugin_cls.__module__}.{plugin_cls.__name__}"
|
||||
try:
|
||||
plugin: PluginType = self.init_inject(plugin_cls)
|
||||
if hasattr(plugin, "__async_init__"):
|
||||
await self.async_inject(plugin.__async_init__)
|
||||
handlers = plugin.handlers
|
||||
for index, handler in enumerate(handlers):
|
||||
if isinstance(handler, TypeHandler): # 对 TypeHandler 进行特殊处理,优先级必须设置 -1,否则无用
|
||||
handlers.pop(index)
|
||||
self.app.add_handler(handler, group=-1)
|
||||
self.app.add_handlers(handlers)
|
||||
if handlers:
|
||||
logger.debug('插件 "%s" 添加了 %s 个 handler ', path, len(handlers))
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
for priority, callback in plugin._new_chat_members_handler_funcs(): # pylint: disable=W0212
|
||||
if not callback_dict.get(priority):
|
||||
callback_dict[priority] = []
|
||||
callback_dict[priority].append(callback)
|
||||
|
||||
error_handlers = plugin.error_handlers
|
||||
for callback, block in error_handlers.items():
|
||||
self.app.add_error_handler(callback, block)
|
||||
if error_handlers:
|
||||
logger.debug('插件 "%s" 添加了 %s 个 error handler ', path, len(error_handlers))
|
||||
|
||||
if jobs := plugin.jobs:
|
||||
logger.debug('插件 "%s" 添加了 %s 个 jobs ', path, len(jobs))
|
||||
logger.success('插件 "%s" 载入成功', path)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception(
|
||||
'在安装插件 "%s" 的过程中遇到了错误 [red bold]%s[/]', path, type(e).__name__, exc_info=e, extra={"markup": True}
|
||||
)
|
||||
if callback_dict:
|
||||
num = sum(len(callback_dict[i]) for i in callback_dict)
|
||||
|
||||
async def _new_chat_member_callback(update: "Update", context: "CallbackContext"):
|
||||
nonlocal callback
|
||||
for _, value in callback_dict.items():
|
||||
for callback in value:
|
||||
await callback(update, context)
|
||||
|
||||
self.app.add_handler(
|
||||
MessageHandler(callback=_new_chat_member_callback, filters=StatusUpdate.NEW_CHAT_MEMBERS, block=False)
|
||||
)
|
||||
logger.success(
|
||||
"成功添加了 %s 个针对 [blue]%s[/] 的 [blue]MessageHandler[/]",
|
||||
num,
|
||||
StatusUpdate.NEW_CHAT_MEMBERS,
|
||||
extra={"markup": True},
|
||||
)
|
||||
# special handler
|
||||
from plugins.system.start import StartPlugin
|
||||
|
||||
self.app.add_handler(
|
||||
MessageHandler(
|
||||
callback=StartPlugin.unknown_command, filters=filters.COMMAND & filters.ChatType.PRIVATE, block=False
|
||||
)
|
||||
)
|
||||
|
||||
async def _start_base_services(self):
|
||||
for pkg in self._gen_pkg(PROJECT_ROOT / "core/base"):
|
||||
try:
|
||||
import_module(pkg)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception(
|
||||
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
||||
)
|
||||
raise SystemExit from e
|
||||
for base_service_cls in Service.__subclasses__():
|
||||
try:
|
||||
if hasattr(base_service_cls, "from_config"):
|
||||
instance = base_service_cls.from_config(self._config)
|
||||
else:
|
||||
instance = self.init_inject(base_service_cls)
|
||||
await instance.start()
|
||||
logger.success('服务 "%s" 初始化成功', base_service_cls.__name__)
|
||||
self._services.update({base_service_cls: instance})
|
||||
except Exception as e:
|
||||
logger.error('服务 "%s" 初始化失败', base_service_cls.__name__)
|
||||
raise SystemExit from e
|
||||
|
||||
async def start_services(self):
|
||||
"""启动服务"""
|
||||
await self._start_base_services()
|
||||
for path in (PROJECT_ROOT / "core").iterdir():
|
||||
if not path.name.startswith("_") and path.is_dir() and path.name != "base":
|
||||
pkg = str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".")
|
||||
try:
|
||||
import_module(pkg)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception(
|
||||
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]',
|
||||
pkg,
|
||||
type(e).__name__,
|
||||
exc_info=e,
|
||||
extra={"markup": True},
|
||||
)
|
||||
continue
|
||||
|
||||
async def stop_services(self):
|
||||
"""关闭服务"""
|
||||
if not self._services:
|
||||
return
|
||||
logger.info("正在关闭服务")
|
||||
for _, service in filter(lambda x: not isinstance(x[1], TgApplication), self._services.items()):
|
||||
async with timeout(5):
|
||||
try:
|
||||
if hasattr(service, "stop"):
|
||||
if inspect.iscoroutinefunction(service.stop):
|
||||
await service.stop()
|
||||
else:
|
||||
service.stop()
|
||||
logger.success('服务 "%s" 关闭成功', service.__class__.__name__)
|
||||
except CancelledError:
|
||||
logger.warning('服务 "%s" 关闭超时', service.__class__.__name__)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception('服务 "%s" 关闭失败', service.__class__.__name__, exc_info=e)
|
||||
|
||||
async def _post_init(self, context: CallbackContext) -> NoReturn:
|
||||
logger.info("开始初始化 genshin.py 相关资源")
|
||||
try:
|
||||
# 替换为 fastgit 镜像源
|
||||
for i in dir(genshin.utility.extdb):
|
||||
if "_URL" in i:
|
||||
setattr(
|
||||
genshin.utility.extdb,
|
||||
i,
|
||||
make_github_fast(getattr(genshin.utility.extdb, i)),
|
||||
)
|
||||
await genshin.utility.update_characters_enka()
|
||||
except Exception as exc: # pylint: disable=W0703
|
||||
logger.error("初始化 genshin.py 相关资源失败")
|
||||
logger.exception(exc)
|
||||
else:
|
||||
logger.success("初始化 genshin.py 相关资源成功")
|
||||
self._services.update({CallbackContext: context})
|
||||
logger.info("开始初始化服务")
|
||||
await self.start_services()
|
||||
logger.info("开始安装插件")
|
||||
await self.install_plugins()
|
||||
logger.info("BOT 初始化成功")
|
||||
|
||||
def launch(self) -> NoReturn:
|
||||
"""启动机器人"""
|
||||
self._running = True
|
||||
logger.info("正在初始化BOT")
|
||||
self.app = (
|
||||
TgApplication.builder()
|
||||
.read_timeout(self.config.read_timeout)
|
||||
.write_timeout(self.config.write_timeout)
|
||||
.connect_timeout(self.config.connect_timeout)
|
||||
.pool_timeout(self.config.pool_timeout)
|
||||
.get_updates_read_timeout(self.config.update_read_timeout)
|
||||
.get_updates_write_timeout(self.config.update_write_timeout)
|
||||
.get_updates_connect_timeout(self.config.update_connect_timeout)
|
||||
.get_updates_pool_timeout(self.config.update_pool_timeout)
|
||||
.rate_limiter(AIORateLimiter())
|
||||
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
|
||||
.token(self._config.bot_token)
|
||||
.post_init(self._post_init)
|
||||
.build()
|
||||
)
|
||||
try:
|
||||
for _ in range(5):
|
||||
try:
|
||||
self.app.run_polling(
|
||||
close_loop=False,
|
||||
timeout=self.config.timeout,
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
)
|
||||
break
|
||||
except TimedOut:
|
||||
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
|
||||
continue
|
||||
except NetworkError as e:
|
||||
if "SSLZeroReturnError" in str(e):
|
||||
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||||
else:
|
||||
logger.error("网络连接出现问题, 请检查您的网络状况.")
|
||||
break
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
pass
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception("BOT 执行过程中出现错误", exc_info=e)
|
||||
finally:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.stop_services())
|
||||
loop.close()
|
||||
logger.info("BOT 已经关闭")
|
||||
self._running = False
|
||||
|
||||
def find_service(self, target: Type[T]) -> T:
|
||||
"""查找服务。若没找到则抛出 ServiceNotFoundError"""
|
||||
if (result := self._services.get(target)) is None:
|
||||
raise ServiceNotFoundError(target)
|
||||
return result
|
||||
|
||||
def add_service(self, service: T) -> NoReturn:
|
||||
"""添加服务。若已经有同类型的服务,则会抛出异常"""
|
||||
if type(service) in self._services:
|
||||
raise ValueError(f'Service "{type(service)}" is already existed.')
|
||||
self.update_service(service)
|
||||
|
||||
def update_service(self, service: T):
|
||||
"""更新服务。若服务不存在,则添加;若存在,则更新"""
|
||||
self._services.update({type(service): service})
|
||||
|
||||
def contain_service(self, service: Any) -> bool:
|
||||
"""判断服务是否存在"""
|
||||
if isinstance(service, type):
|
||||
return service in self._services
|
||||
else:
|
||||
return service in self._services.values()
|
||||
|
||||
@property
|
||||
def job_queue(self) -> JobQueue:
|
||||
return self.app.job_queue
|
||||
|
||||
@property
|
||||
def services(self) -> Dict[Type[T], T]:
|
||||
return self._services
|
||||
|
||||
@property
|
||||
def config(self) -> BotConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running
|
||||
|
||||
|
||||
bot = Bot()
|
1
core/builtins/__init__.py
Normal file
1
core/builtins/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""bot builtins"""
|
38
core/builtins/contexts.py
Normal file
38
core/builtins/contexts.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""上下文管理"""
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram.ext import CallbackContext
|
||||
from telegram import Update
|
||||
|
||||
__all__ = [
|
||||
"CallbackContextCV",
|
||||
"UpdateCV",
|
||||
"handler_contexts",
|
||||
"job_contexts",
|
||||
]
|
||||
|
||||
CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback")
|
||||
UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def handler_contexts(update: "Update", context: "CallbackContext") -> None:
|
||||
context_token = CallbackContextCV.set(context)
|
||||
update_token = UpdateCV.set(update)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CallbackContextCV.reset(context_token)
|
||||
UpdateCV.reset(update_token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def job_contexts(context: "CallbackContext") -> None:
|
||||
token = CallbackContextCV.set(context)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CallbackContextCV.reset(token)
|
309
core/builtins/dispatcher.py
Normal file
309
core/builtins/dispatcher.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""参数分发器"""
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import AbstractEventLoop
|
||||
from functools import cached_property, lru_cache, partial, wraps
|
||||
from inspect import Parameter, Signature
|
||||
from itertools import chain
|
||||
from types import GenericAlias, MethodType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from arkowrapper import ArkoWrapper
|
||||
from fastapi import FastAPI
|
||||
from telegram import Bot as TelegramBot, Chat, Message, Update, User
|
||||
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
|
||||
from typing_extensions import ParamSpec
|
||||
from uvicorn import Server
|
||||
|
||||
from core.application import Application
|
||||
from utils.const import WRAPPER_ASSIGNMENTS
|
||||
from utils.typedefs import R, T
|
||||
|
||||
__all__ = (
|
||||
"catch",
|
||||
"AbstractDispatcher",
|
||||
"BaseDispatcher",
|
||||
"HandlerDispatcher",
|
||||
"JobDispatcher",
|
||||
"dispatched",
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
TargetType = Union[Type, str, Callable[[Any], bool]]
|
||||
|
||||
_CATCH_TARGET_ATTR = "_catch_targets"
|
||||
|
||||
|
||||
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
||||
setattr(func, _CATCH_TARGET_ATTR, targets)
|
||||
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
@lru_cache(64)
|
||||
def get_signature(func: Union[type, Callable]) -> Signature:
|
||||
if isinstance(func, type):
|
||||
return inspect.signature(func.__init__)
|
||||
return inspect.signature(func)
|
||||
|
||||
|
||||
class AbstractDispatcher(ABC):
|
||||
"""参数分发器"""
|
||||
|
||||
IGNORED_ATTRS = []
|
||||
|
||||
_args: List[Any] = []
|
||||
_kwargs: Dict[Union[str, Type], Any] = {}
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||
return self._application
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._args = list(args)
|
||||
self._kwargs = dict(kwargs)
|
||||
|
||||
for _, value in kwargs.items():
|
||||
type_arg = type(value)
|
||||
if type_arg != str:
|
||||
self._kwargs[type_arg] = value
|
||||
|
||||
for arg in args:
|
||||
type_arg = type(arg)
|
||||
if type_arg != str:
|
||||
self._kwargs[type_arg] = arg
|
||||
|
||||
@cached_property
|
||||
def catch_funcs(self) -> List[MethodType]:
|
||||
# noinspection PyTypeChecker
|
||||
return list(
|
||||
ArkoWrapper(dir(self))
|
||||
.filter(lambda x: not x.startswith("_"))
|
||||
.filter(
|
||||
lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"]
|
||||
)
|
||||
.map(lambda x: getattr(self, x))
|
||||
.filter(lambda x: isinstance(x, MethodType))
|
||||
.filter(lambda x: hasattr(x, "_catch_targets"))
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
|
||||
result = {}
|
||||
for catch_func in self.catch_funcs:
|
||||
catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR)
|
||||
for catch_target in catch_targets:
|
||||
result[catch_target] = catch_func
|
||||
return result
|
||||
|
||||
@cached_property
|
||||
def dispatch_funcs(self) -> List[MethodType]:
|
||||
return list(
|
||||
ArkoWrapper(dir(self))
|
||||
.filter(lambda x: x.startswith("dispatch_by_"))
|
||||
.map(lambda x: getattr(self, x))
|
||||
.filter(lambda x: isinstance(x, MethodType))
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||
"""默认的 dispatch 方法"""
|
||||
|
||||
@abstractmethod
|
||||
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
||||
"""使用 catch_func 获取并分配参数"""
|
||||
|
||||
def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
|
||||
"""将参数分配给函数,从而合成一个无需参数即可执行的函数"""
|
||||
params = {}
|
||||
signature = get_signature(func)
|
||||
parameters: Dict[str, Parameter] = dict(signature.parameters)
|
||||
|
||||
for name, parameter in list(parameters.items()):
|
||||
parameter: Parameter
|
||||
if any(
|
||||
[
|
||||
name == "self" and isinstance(func, (type, MethodType)),
|
||||
parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL],
|
||||
]
|
||||
):
|
||||
del parameters[name]
|
||||
continue
|
||||
|
||||
for dispatch_func in self.dispatch_funcs:
|
||||
parameters[name] = dispatch_func(parameter)
|
||||
|
||||
for name, parameter in parameters.items():
|
||||
if parameter.default != Parameter.empty:
|
||||
params[name] = parameter.default
|
||||
else:
|
||||
params[name] = None
|
||||
|
||||
return partial(func, **params)
|
||||
|
||||
@catch(Application)
|
||||
def catch_application(self) -> Application:
|
||||
return self.application
|
||||
|
||||
|
||||
class BaseDispatcher(AbstractDispatcher):
|
||||
"""默认参数分发器"""
|
||||
|
||||
_instances: Sequence[Any]
|
||||
|
||||
def _get_kwargs(self) -> Dict[Type[T], T]:
|
||||
result = self._get_default_kwargs()
|
||||
result[AbstractDispatcher] = self
|
||||
result.update(self._kwargs)
|
||||
return result
|
||||
|
||||
def _get_default_kwargs(self) -> Dict[Type[T], T]:
|
||||
application = self.application
|
||||
_default_kwargs = {
|
||||
FastAPI: application.web_app,
|
||||
Server: application.web_server,
|
||||
TelegramApplication: application.telegram,
|
||||
TelegramBot: application.telegram.bot,
|
||||
}
|
||||
if not application.running:
|
||||
for obj in chain(
|
||||
application.managers.dependency,
|
||||
application.managers.components,
|
||||
application.managers.services,
|
||||
application.managers.plugins,
|
||||
):
|
||||
_default_kwargs[type(obj)] = obj
|
||||
return {k: v for k, v in _default_kwargs.items() if v is not None}
|
||||
|
||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||
annotation = parameter.annotation
|
||||
# noinspection PyTypeChecker
|
||||
if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None:
|
||||
parameter._default = value # pylint: disable=W0212
|
||||
return parameter
|
||||
|
||||
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
||||
annotation = parameter.annotation
|
||||
if annotation != Any and isinstance(annotation, GenericAlias):
|
||||
return parameter
|
||||
|
||||
catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name)
|
||||
if catch_func is not None:
|
||||
# noinspection PyUnresolvedReferences,PyProtectedMember
|
||||
parameter._default = catch_func() # pylint: disable=W0212
|
||||
return parameter
|
||||
|
||||
@catch(AbstractEventLoop)
|
||||
def catch_loop(self) -> AbstractEventLoop:
|
||||
return asyncio.get_event_loop()
|
||||
|
||||
|
||||
class HandlerDispatcher(BaseDispatcher):
|
||||
"""Handler 参数分发器"""
|
||||
|
||||
def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
||||
super().__init__(update=update, context=context, **kwargs)
|
||||
self._update = update
|
||||
self._context = context
|
||||
|
||||
def dispatch(
|
||||
self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
|
||||
) -> Callable[..., R]:
|
||||
self._update = update or self._update
|
||||
self._context = context or self._context
|
||||
if self._update is None:
|
||||
from core.builtins.contexts import UpdateCV
|
||||
|
||||
self._update = UpdateCV.get()
|
||||
if self._context is None:
|
||||
from core.builtins.contexts import CallbackContextCV
|
||||
|
||||
self._context = CallbackContextCV.get()
|
||||
return super().dispatch(func)
|
||||
|
||||
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||
"""HandlerDispatcher 默认不使用 dispatch_by_default"""
|
||||
return parameter
|
||||
|
||||
@catch(Update)
|
||||
def catch_update(self) -> Update:
|
||||
return self._update
|
||||
|
||||
@catch(CallbackContext)
|
||||
def catch_context(self) -> CallbackContext:
|
||||
return self._context
|
||||
|
||||
@catch(Message)
|
||||
def catch_message(self) -> Message:
|
||||
return self._update.effective_message
|
||||
|
||||
@catch(User)
|
||||
def catch_user(self) -> User:
|
||||
return self._update.effective_user
|
||||
|
||||
@catch(Chat)
|
||||
def catch_chat(self) -> Chat:
|
||||
return self._update.effective_chat
|
||||
|
||||
|
||||
class JobDispatcher(BaseDispatcher):
|
||||
"""Job 参数分发器"""
|
||||
|
||||
def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
||||
super().__init__(context=context, **kwargs)
|
||||
self._context = context
|
||||
|
||||
def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
|
||||
self._context = context or self._context
|
||||
if self._context is None:
|
||||
from core.builtins.contexts import CallbackContextCV
|
||||
|
||||
self._context = CallbackContextCV.get()
|
||||
return super().dispatch(func)
|
||||
|
||||
@catch("data")
|
||||
def catch_data(self) -> Any:
|
||||
return self._context.job.data
|
||||
|
||||
@catch(Job)
|
||||
def catch_job(self) -> Job:
|
||||
return self._context.job
|
||||
|
||||
@catch(CallbackContext)
|
||||
def catch_context(self) -> CallbackContext:
|
||||
return self._context
|
||||
|
||||
|
||||
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
|
||||
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return dispatcher().dispatch(func)(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorate
|
131
core/builtins/executor.py
Normal file
131
core/builtins/executor.py
Normal file
@ -0,0 +1,131 @@
|
||||
"""执行器"""
|
||||
import inspect
|
||||
from functools import cached_property
|
||||
from multiprocessing import RLock as Lock
|
||||
from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import CallbackContext
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
from core.builtins.contexts import handler_contexts, job_contexts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher
|
||||
from multiprocessing.synchronize import RLock as LockType
|
||||
|
||||
__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor")
|
||||
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
"""执行器
|
||||
Args:
|
||||
name(str): 该执行器的名称。执行器的名称是唯一的。
|
||||
|
||||
只支持执行只拥有 POSITIONAL_OR_KEYWORD 和 KEYWORD_ONLY 两种参数类型的函数
|
||||
"""
|
||||
|
||||
_lock: ClassVar["LockType"] = Lock()
|
||||
_instances: ClassVar[Dict[str, Self]] = {}
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||
return self._application
|
||||
|
||||
def __new__(cls: Type[T], name: str, *args, **kwargs) -> T:
|
||||
with cls._lock:
|
||||
if (instance := cls._instances.get(name)) is None:
|
||||
instance = object.__new__(cls)
|
||||
instance.__init__(name, *args, **kwargs)
|
||||
cls._instances.update({name: instance})
|
||||
return instance
|
||||
|
||||
@cached_property
|
||||
def name(self) -> str:
|
||||
"""当前执行器的名称"""
|
||||
return self._name
|
||||
|
||||
def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
||||
self._name = name
|
||||
self._dispatcher = dispatcher
|
||||
|
||||
|
||||
class Executor(BaseExecutor, Generic[P, R]):
|
||||
async def __call__(
|
||||
self,
|
||||
target: Callable[P, R],
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
**kwargs,
|
||||
) -> R:
|
||||
dispatcher = self._dispatcher or dispatcher
|
||||
dispatcher_instance = dispatcher(**kwargs)
|
||||
dispatcher_instance.set_application(application=self.application)
|
||||
dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数
|
||||
|
||||
# 执行
|
||||
if inspect.iscoroutinefunction(target):
|
||||
result = await dispatched_func()
|
||||
else:
|
||||
result = dispatched_func()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class HandlerExecutor(BaseExecutor, Generic[P, R]):
|
||||
"""Handler专用执行器"""
|
||||
|
||||
_callback: Callable[P, R]
|
||||
_dispatcher: "HandlerDispatcher"
|
||||
|
||||
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None:
|
||||
if dispatcher is None:
|
||||
from core.builtins.dispatcher import HandlerDispatcher
|
||||
|
||||
dispatcher = HandlerDispatcher
|
||||
super().__init__("handler", dispatcher)
|
||||
self._callback = func
|
||||
self._dispatcher = dispatcher()
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
if self._dispatcher is not None:
|
||||
self._dispatcher.set_application(application)
|
||||
|
||||
async def __call__(self, update: Update, context: CallbackContext) -> R:
|
||||
with handler_contexts(update, context):
|
||||
dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context)
|
||||
return await dispatched_func()
|
||||
|
||||
|
||||
class JobExecutor(BaseExecutor):
|
||||
"""Job 专用执行器"""
|
||||
|
||||
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
||||
if dispatcher is None:
|
||||
from core.builtins.dispatcher import JobDispatcher
|
||||
|
||||
dispatcher = JobDispatcher
|
||||
super().__init__("job", dispatcher)
|
||||
self._callback = func
|
||||
self._dispatcher = dispatcher()
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
if self._dispatcher is not None:
|
||||
self._dispatcher.set_application(application)
|
||||
|
||||
async def __call__(self, context: CallbackContext) -> R:
|
||||
with job_contexts(context):
|
||||
dispatched_func = self._dispatcher.dispatch(self._callback, context=context)
|
||||
return await dispatched_func()
|
185
core/builtins/reloader.py
Normal file
185
core/builtins/reloader.py
Normal file
@ -0,0 +1,185 @@
|
||||
import inspect
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from watchfiles import watch
|
||||
|
||||
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
|
||||
from utils.log import logger
|
||||
from utils.typedefs import StrOrPath
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from multiprocessing.process import BaseProcess
|
||||
|
||||
__all__ = ("Reloader",)
|
||||
|
||||
multiprocessing.allow_connection_pickling()
|
||||
spawn = multiprocessing.get_context("spawn")
|
||||
|
||||
|
||||
class FileFilter:
|
||||
"""监控文件过滤"""
|
||||
|
||||
def __init__(self, includes: List[str], excludes: List[str]) -> None:
|
||||
default_includes = ["*.py"]
|
||||
self.includes = [default for default in default_includes if default not in excludes]
|
||||
self.includes.extend(includes)
|
||||
self.includes = list(set(self.includes))
|
||||
|
||||
default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
|
||||
self.excludes = [default for default in default_excludes if default not in includes]
|
||||
self.exclude_dirs = []
|
||||
for e in excludes:
|
||||
p = Path(e)
|
||||
try:
|
||||
is_dir = p.is_dir()
|
||||
except OSError:
|
||||
is_dir = False
|
||||
|
||||
if is_dir:
|
||||
self.exclude_dirs.append(p)
|
||||
else:
|
||||
self.excludes.append(e)
|
||||
self.excludes = list(set(self.excludes))
|
||||
|
||||
def __call__(self, path: Path) -> bool:
|
||||
for include_pattern in self.includes:
|
||||
if path.match(include_pattern):
|
||||
for exclude_dir in self.exclude_dirs:
|
||||
if exclude_dir in path.parents:
|
||||
return False
|
||||
|
||||
for exclude_pattern in self.excludes:
|
||||
if path.match(exclude_pattern):
|
||||
return False
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Reloader:
|
||||
_target: Callable[..., None]
|
||||
_process: "BaseProcess"
|
||||
|
||||
@property
|
||||
def process(self) -> "BaseProcess":
|
||||
return self._process
|
||||
|
||||
@property
|
||||
def target(self) -> Callable[..., None]:
|
||||
return self._target
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Callable[..., None],
|
||||
*,
|
||||
reload_delay: float = 0.25,
|
||||
reload_dirs: List[StrOrPath] = None,
|
||||
reload_includes: List[str] = None,
|
||||
reload_excludes: List[str] = None,
|
||||
):
|
||||
if inspect.iscoroutinefunction(target):
|
||||
raise ValueError("不支持异步函数")
|
||||
self._target = target
|
||||
|
||||
self.reload_delay = reload_delay
|
||||
|
||||
_reload_dirs = []
|
||||
for reload_dir in reload_dirs or []:
|
||||
_reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir)))
|
||||
|
||||
self.reload_dirs = []
|
||||
for reload_dir in _reload_dirs:
|
||||
append = True
|
||||
for parent in reload_dir.parents:
|
||||
if parent in _reload_dirs:
|
||||
append = False
|
||||
break
|
||||
if append:
|
||||
self.reload_dirs.append(reload_dir)
|
||||
|
||||
if not self.reload_dirs:
|
||||
logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})
|
||||
|
||||
self._should_exit = threading.Event()
|
||||
|
||||
frame = inspect.currentframe().f_back
|
||||
|
||||
self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]])
|
||||
self.watcher = watch(
|
||||
*self.reload_dirs,
|
||||
watch_filter=None,
|
||||
stop_event=self._should_exit,
|
||||
yield_on_timeout=True,
|
||||
)
|
||||
|
||||
def get_changes(self) -> Optional[List[Path]]:
|
||||
if not self._process.is_alive():
|
||||
logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
|
||||
self._should_exit.set()
|
||||
try:
|
||||
changes = next(self.watcher)
|
||||
except StopIteration:
|
||||
return None
|
||||
if changes:
|
||||
unique_paths = {Path(c[1]) for c in changes}
|
||||
return [p for p in unique_paths if self.watch_filter(p)]
|
||||
return None
|
||||
|
||||
def __iter__(self) -> Iterator[Optional[List[Path]]]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> Optional[List[Path]]:
|
||||
return self.get_changes()
|
||||
|
||||
def run(self) -> None:
|
||||
self.startup()
|
||||
for changes in self:
|
||||
if changes:
|
||||
logger.warning(
|
||||
"检测到文件 %s 发生改变, 正在重载...",
|
||||
[str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes],
|
||||
extra={"tag": "Reloader"},
|
||||
)
|
||||
self.restart()
|
||||
|
||||
self.shutdown()
|
||||
|
||||
def signal_handler(self, *_) -> None:
|
||||
"""当接收到结束信号量时"""
|
||||
self._process.join(3)
|
||||
if self._process.is_alive():
|
||||
self._process.terminate()
|
||||
self._process.join()
|
||||
self._should_exit.set()
|
||||
|
||||
def startup(self) -> None:
|
||||
"""启动进程"""
|
||||
logger.info("目标进程正在启动", extra={"tag": "Reloader"})
|
||||
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.signal_handler)
|
||||
|
||||
self._process = spawn.Process(target=self._target)
|
||||
self._process.start()
|
||||
logger.success("目标进程启动成功", extra={"tag": "Reloader"})
|
||||
|
||||
def restart(self) -> None:
|
||||
"""重启进程"""
|
||||
self._process.terminate()
|
||||
self._process.join(10)
|
||||
|
||||
self._process = spawn.Process(target=self._target)
|
||||
self._process.start()
|
||||
logger.info("目标进程已经重载", extra={"tag": "Reloader"})
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""关闭进程"""
|
||||
self._process.terminate()
|
||||
self._process.join(10)
|
||||
|
||||
logger.info("重载器已经关闭", extra={"tag": "Reloader"})
|
@ -1,19 +1,15 @@
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import dotenv
|
||||
from pydantic import AnyUrl, BaseModel, Field
|
||||
from pydantic import AnyUrl, Field
|
||||
|
||||
from core.basemodel import Settings
|
||||
from utils.const import PROJECT_ROOT
|
||||
from utils.models.base import Settings
|
||||
from utils.typedefs import NaturalNumber
|
||||
|
||||
__all__ = ["BotConfig", "config", "JoinGroups"]
|
||||
__all__ = ("ApplicationConfig", "config", "JoinGroups")
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@ -25,22 +21,12 @@ class JoinGroups(str, Enum):
|
||||
ALLOW_ALL = "ALLOW_ALL"
|
||||
|
||||
|
||||
class ConfigChannel(BaseModel):
|
||||
name: str
|
||||
chat_id: int
|
||||
|
||||
|
||||
class ConfigUser(BaseModel):
|
||||
username: Optional[str]
|
||||
user_id: int
|
||||
|
||||
|
||||
class MySqlConfig(Settings):
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 3306
|
||||
username: str = None
|
||||
password: str = None
|
||||
database: str = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
database: Optional[str] = None
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "db_"
|
||||
@ -58,7 +44,7 @@ class RedisConfig(Settings):
|
||||
|
||||
class LoggerConfig(Settings):
|
||||
name: str = "TGPaimon"
|
||||
width: int = 180
|
||||
width: Optional[int] = None
|
||||
time_format: str = "[%Y-%m-%d %X]"
|
||||
traceback_max_frames: int = 20
|
||||
path: Path = PROJECT_ROOT / "logs"
|
||||
@ -78,6 +64,9 @@ class MTProtoConfig(Settings):
|
||||
|
||||
|
||||
class WebServerConfig(Settings):
|
||||
enable: bool = False
|
||||
"""是否启用WebServer"""
|
||||
|
||||
url: AnyUrl = "http://localhost:8080"
|
||||
host: str = "localhost"
|
||||
port: int = 8080
|
||||
@ -97,6 +86,16 @@ class ErrorConfig(Settings):
|
||||
env_prefix = "error_"
|
||||
|
||||
|
||||
class ReloadConfig(Settings):
|
||||
delay: float = 0.25
|
||||
dirs: List[str] = []
|
||||
include: List[str] = []
|
||||
exclude: List[str] = []
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "reload_"
|
||||
|
||||
|
||||
class NoticeConfig(Settings):
|
||||
user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!"
|
||||
|
||||
@ -104,24 +103,32 @@ class NoticeConfig(Settings):
|
||||
env_prefix = "notice_"
|
||||
|
||||
|
||||
class PluginConfig(Settings):
|
||||
download_file_max_size: int = 5
|
||||
|
||||
class Config(Settings.Config):
|
||||
env_prefix = "plugin_"
|
||||
|
||||
|
||||
class BotConfig(Settings):
|
||||
class ApplicationConfig(Settings):
|
||||
debug: bool = False
|
||||
"""debug 开关"""
|
||||
retry: int = 5
|
||||
"""重试次数"""
|
||||
auto_reload: bool = False
|
||||
"""自动重载"""
|
||||
|
||||
proxy_url: Optional[AnyUrl] = None
|
||||
"""代理链接"""
|
||||
|
||||
bot_token: str = ""
|
||||
"""BOT的token"""
|
||||
|
||||
owner: Optional[int] = None
|
||||
|
||||
channels: List[int] = []
|
||||
"""文章推送群组"""
|
||||
|
||||
channels: List["ConfigChannel"] = []
|
||||
admins: List["ConfigUser"] = []
|
||||
verify_groups: List[Union[int, str]] = []
|
||||
"""启用群验证功能的群组"""
|
||||
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
|
||||
"""是否允许机器人被邀请到其它群组"""
|
||||
|
||||
timeout: int = 10
|
||||
connection_pool_size: int = 256
|
||||
read_timeout: Optional[float] = None
|
||||
write_timeout: Optional[float] = None
|
||||
connect_timeout: Optional[float] = None
|
||||
@ -138,6 +145,7 @@ class BotConfig(Settings):
|
||||
pass_challenge_app_key: str = ""
|
||||
pass_challenge_user_web: str = ""
|
||||
|
||||
reload: ReloadConfig = ReloadConfig()
|
||||
mysql: MySqlConfig = MySqlConfig()
|
||||
logger: LoggerConfig = LoggerConfig()
|
||||
webserver: WebServerConfig = WebServerConfig()
|
||||
@ -145,8 +153,7 @@ class BotConfig(Settings):
|
||||
mtproto: MTProtoConfig = MTProtoConfig()
|
||||
error: ErrorConfig = ErrorConfig()
|
||||
notice: NoticeConfig = NoticeConfig()
|
||||
plugin: PluginConfig = PluginConfig()
|
||||
|
||||
|
||||
BotConfig.update_forward_refs()
|
||||
config = BotConfig()
|
||||
ApplicationConfig.update_forward_refs()
|
||||
config = ApplicationConfig()
|
||||
|
@ -1,21 +0,0 @@
|
||||
from core.base.mysql import MySQL
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.cookies.cache import PublicCookiesCache
|
||||
from core.cookies.repositories import CookiesRepository
|
||||
from core.cookies.services import CookiesService, PublicCookiesService
|
||||
from core.service import init_service
|
||||
|
||||
|
||||
@init_service
|
||||
def create_cookie_service(mysql: MySQL):
|
||||
_repository = CookiesRepository(mysql)
|
||||
_service = CookiesService(_repository)
|
||||
return _service
|
||||
|
||||
|
||||
@init_service
|
||||
def create_public_cookie_service(mysql: MySQL, redis: RedisDB):
|
||||
_repository = CookiesRepository(mysql)
|
||||
_cache = PublicCookiesCache(redis)
|
||||
_service = PublicCookiesService(_repository, _cache)
|
||||
return _service
|
@ -1,27 +0,0 @@
|
||||
import enum
|
||||
from typing import Optional, Dict
|
||||
|
||||
from sqlmodel import SQLModel, Field, JSON, Enum, Column
|
||||
|
||||
|
||||
class CookiesStatusEnum(int, enum.Enum):
|
||||
STATUS_SUCCESS = 0
|
||||
INVALID_COOKIES = 1
|
||||
TOO_MANY_REQUESTS = 2
|
||||
|
||||
|
||||
class Cookies(SQLModel):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
id: int = Field(primary_key=True)
|
||||
user_id: Optional[int] = Field(foreign_key="user.user_id")
|
||||
cookies: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
|
||||
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
|
||||
|
||||
|
||||
class HyperionCookie(Cookies, table=True):
|
||||
__tablename__ = "mihoyo_cookies"
|
||||
|
||||
|
||||
class HoyolabCookie(Cookies, table=True):
|
||||
__tablename__ = "hoyoverse_cookies"
|
@ -1,109 +0,0 @@
|
||||
from typing import cast, List
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from core.base.mysql import MySQL
|
||||
from utils.error import RegionNotFoundError
|
||||
from utils.models.base import RegionEnum
|
||||
from .error import CookiesNotFoundError
|
||||
from .models import HyperionCookie, HoyolabCookie, Cookies
|
||||
|
||||
|
||||
class CookiesRepository:
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.mysql = mysql
|
||||
|
||||
async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
||||
async with self.mysql.Session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
if region == RegionEnum.HYPERION:
|
||||
db_data = HyperionCookie(user_id=user_id, cookies=cookies)
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
db_data = HoyolabCookie(user_id=user_id, cookies=cookies)
|
||||
else:
|
||||
raise RegionNotFoundError(region.name)
|
||||
session.add(db_data)
|
||||
await session.commit()
|
||||
|
||||
async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
||||
async with self.mysql.Session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
if region == RegionEnum.HYPERION:
|
||||
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
|
||||
else:
|
||||
raise RegionNotFoundError(region.name)
|
||||
results = await session.exec(statement)
|
||||
db_cookies = results.first()
|
||||
if db_cookies is None:
|
||||
raise CookiesNotFoundError(user_id)
|
||||
db_cookies = db_cookies[0]
|
||||
db_cookies.cookies = cookies
|
||||
session.add(db_cookies)
|
||||
await session.commit()
|
||||
await session.refresh(db_cookies)
|
||||
|
||||
async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum):
|
||||
async with self.mysql.Session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
if region not in [RegionEnum.HYPERION, RegionEnum.HOYOLAB]:
|
||||
raise RegionNotFoundError(region.name)
|
||||
session.add(cookies)
|
||||
await session.commit()
|
||||
await session.refresh(cookies)
|
||||
|
||||
async def get_cookies(self, user_id, region: RegionEnum) -> Cookies:
|
||||
async with self.mysql.Session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
if region == RegionEnum.HYPERION:
|
||||
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
db_cookies = results.first()
|
||||
if db_cookies is None:
|
||||
raise CookiesNotFoundError(user_id)
|
||||
return db_cookies[0]
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
db_cookies = results.first()
|
||||
if db_cookies is None:
|
||||
raise CookiesNotFoundError(user_id)
|
||||
return db_cookies[0]
|
||||
else:
|
||||
raise RegionNotFoundError(region.name)
|
||||
|
||||
async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]:
|
||||
async with self.mysql.Session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
if region == RegionEnum.HYPERION:
|
||||
statement = select(HyperionCookie)
|
||||
results = await session.exec(statement)
|
||||
db_cookies = results.all()
|
||||
return [cookies[0] for cookies in db_cookies]
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
statement = select(HoyolabCookie)
|
||||
results = await session.exec(statement)
|
||||
db_cookies = results.all()
|
||||
return [cookies[0] for cookies in db_cookies]
|
||||
else:
|
||||
raise RegionNotFoundError(region.name)
|
||||
|
||||
async def del_cookies(self, user_id, region: RegionEnum):
|
||||
async with self.mysql.Session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
if region == RegionEnum.HYPERION:
|
||||
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
|
||||
else:
|
||||
raise RegionNotFoundError(region.name)
|
||||
results = await session.execute(statement)
|
||||
try:
|
||||
db_cookies = results.unique().scalar_one()
|
||||
except NoResultFound as exc:
|
||||
raise CookiesNotFoundError(user_id) from exc
|
||||
await session.delete(db_cookies)
|
||||
await session.commit()
|
1
core/dependence/__init__.py
Normal file
1
core/dependence/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""基础服务"""
|
@ -1,26 +1,40 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from playwright.async_api import Browser, Playwright, async_playwright, Error
|
||||
from playwright.async_api import Error, async_playwright
|
||||
|
||||
from core.service import Service
|
||||
from core.base_service import BaseService
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from playwright.async_api import Playwright as AsyncPlaywright, Browser
|
||||
|
||||
__all__ = ("AioBrowser",)
|
||||
|
||||
|
||||
class AioBrowser(BaseService.Dependence):
|
||||
@property
|
||||
def browser(self):
|
||||
return self._browser
|
||||
|
||||
class AioBrowser(Service):
|
||||
def __init__(self, loop=None):
|
||||
self.browser: Optional[Browser] = None
|
||||
self._playwright: Optional[Playwright] = None
|
||||
self._browser: Optional["Browser"] = None
|
||||
self._playwright: Optional["AsyncPlaywright"] = None
|
||||
self._loop = loop
|
||||
|
||||
async def start(self):
|
||||
async def get_browser(self):
|
||||
if self._browser is None:
|
||||
await self.initialize()
|
||||
return self._browser
|
||||
|
||||
async def initialize(self):
|
||||
if self._playwright is None:
|
||||
logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True})
|
||||
self._playwright = await async_playwright().start()
|
||||
logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True})
|
||||
if self.browser is None:
|
||||
if self._browser is None:
|
||||
logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True})
|
||||
try:
|
||||
self.browser = await self._playwright.chromium.launch(timeout=5000)
|
||||
self._browser = await self._playwright.chromium.launch(timeout=5000)
|
||||
logger.success("[blue]Browser[/] 启动成功", extra={"markup": True})
|
||||
except Error as err:
|
||||
if "playwright install" in str(err):
|
||||
@ -33,15 +47,10 @@ class AioBrowser(Service):
|
||||
raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium")
|
||||
raise err
|
||||
|
||||
return self.browser
|
||||
return self._browser
|
||||
|
||||
async def stop(self):
|
||||
if self.browser is not None:
|
||||
await self.browser.close()
|
||||
async def shutdown(self):
|
||||
if self._browser is not None:
|
||||
await self._browser.close()
|
||||
if self._playwright is not None:
|
||||
await self._playwright.stop()
|
||||
|
||||
async def get_browser(self) -> Browser:
|
||||
if self.browser is None:
|
||||
await self.start()
|
||||
return self.browser
|
||||
self._playwright.stop()
|
16
core/dependence/aiobrowser.pyi
Normal file
16
core/dependence/aiobrowser.pyi
Normal file
@ -0,0 +1,16 @@
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
from playwright.async_api import Browser, Playwright as AsyncPlaywright
|
||||
|
||||
from core.base_service import BaseService
|
||||
|
||||
__all__ = ("AioBrowser",)
|
||||
|
||||
class AioBrowser(BaseService.Dependence):
|
||||
_browser: Browser | None
|
||||
_playwright: AsyncPlaywright | None
|
||||
_loop: AbstractEventLoop
|
||||
|
||||
@property
|
||||
def browser(self) -> Browser | None: ...
|
||||
async def get_browser(self) -> Browser: ...
|
@ -17,7 +17,7 @@ from enkanetwork.model.assets import CharacterAsset as EnkaCharacterAsset
|
||||
from httpx import AsyncClient, HTTPError, HTTPStatusError, TransportError, URL
|
||||
from typing_extensions import Self
|
||||
|
||||
from core.service import Service
|
||||
from core.base_service import BaseService
|
||||
from metadata.genshin import AVATAR_DATA, HONEY_DATA, MATERIAL_DATA, NAMECARD_DATA, WEAPON_DATA
|
||||
from metadata.scripts.honey import update_honey_metadata
|
||||
from metadata.scripts.metadatas import update_metadata_from_ambr, update_metadata_from_github
|
||||
@ -31,6 +31,8 @@ if TYPE_CHECKING:
|
||||
from httpx import Response
|
||||
from multiprocessing.synchronize import RLock
|
||||
|
||||
__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets")
|
||||
|
||||
ICON_TYPE = Union[Callable[[bool], Awaitable[Optional[Path]]], Callable[..., Awaitable[Optional[Path]]]]
|
||||
NAME_MAP_TYPE = Dict[str, StrOrURL]
|
||||
|
||||
@ -127,7 +129,7 @@ class _AssetsService(ABC):
|
||||
|
||||
async def _download(self, url: StrOrURL, path: Path, retry: int = 5) -> Path | None:
|
||||
"""从 url 下载图标至 path"""
|
||||
logger.debug(f"正在从 {url} 下载图标至 {path}")
|
||||
logger.debug("正在从 %s 下载图标至 %s", url, path)
|
||||
headers = {"user-agent": "TGPaimonBot/3.0"} if URL(url).host == "enka.network" else None
|
||||
for time in range(retry):
|
||||
try:
|
||||
@ -204,8 +206,8 @@ class _AssetsService(ABC):
|
||||
"""魔法"""
|
||||
if item in self.icon_types:
|
||||
return partial(self._get_img, item=item)
|
||||
else:
|
||||
object.__getattribute__(self, item)
|
||||
object.__getattribute__(self, item)
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
@cached_property
|
||||
@ -498,7 +500,7 @@ class _NamecardAssets(_AssetsService):
|
||||
}
|
||||
|
||||
|
||||
class AssetsService(Service):
|
||||
class AssetsService(BaseService.Dependence):
|
||||
"""asset服务
|
||||
|
||||
用于储存和管理 asset :
|
||||
@ -527,8 +529,10 @@ class AssetsService(Service):
|
||||
):
|
||||
setattr(self, attr, globals()[assets_type_name]())
|
||||
|
||||
async def start(self): # pylint: disable=R0201
|
||||
async def initialize(self) -> None: # pylint: disable=R0201
|
||||
"""启动 AssetsService 服务,刷新元数据"""
|
||||
logger.info("正在刷新元数据")
|
||||
# todo 这3个任务同时异步下载
|
||||
await update_metadata_from_github(False)
|
||||
await update_metadata_from_ambr(False)
|
||||
await update_honey_metadata(False)
|
167
core/dependence/assets.pyi
Normal file
167
core/dependence/assets.pyi
Normal file
@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable, ClassVar, TypeVar
|
||||
|
||||
from enkanetwork import Assets as EnkaAssets
|
||||
from enkanetwork.model.assets import CharacterAsset as EnkaCharacterAsset
|
||||
from httpx import AsyncClient
|
||||
from typing_extensions import Self
|
||||
|
||||
from core.base_service import BaseService
|
||||
from utils.typedefs import StrOrInt
|
||||
|
||||
__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets")
|
||||
|
||||
ICON_TYPE = Callable[[bool], Awaitable[Path | None]] | Callable[..., Awaitable[Path | None]]
|
||||
DEFAULT_EnkaAssets: EnkaAssets
|
||||
_GET_TYPE = partial | list[str] | int | str | ICON_TYPE | Path | AsyncClient | None | Self | dict[str, str]
|
||||
|
||||
class AssetsServiceError(Exception): ...
|
||||
|
||||
class AssetsCouldNotFound(AssetsServiceError):
|
||||
message: str
|
||||
target: str
|
||||
def __init__(self, message: str, target: str): ...
|
||||
|
||||
class _AssetsService(ABC):
|
||||
icon_types: ClassVar[list[str]]
|
||||
id: int
|
||||
type: str
|
||||
|
||||
icon: ICON_TYPE
|
||||
"""图标"""
|
||||
|
||||
@abstractmethod
|
||||
@property
|
||||
def game_name(self) -> str:
|
||||
"""游戏数据中的名称"""
|
||||
@property
|
||||
def honey_id(self) -> str:
|
||||
"""当前资源在 Honey Impact 所对应的 ID"""
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
"""当前资源的文件夹"""
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
"""当前的 http client"""
|
||||
def __init__(self, client: AsyncClient | None = None) -> None: ...
|
||||
def __call__(self, target: int) -> Self:
|
||||
"""用于生成与 target 对应的 assets"""
|
||||
def __getattr__(self, item: str) -> _GET_TYPE:
|
||||
"""魔法"""
|
||||
async def get_link(self, item: str) -> str | None:
|
||||
"""获取相应图标链接"""
|
||||
@abstractmethod
|
||||
@property
|
||||
def game_name_map(self) -> dict[str, str]:
|
||||
"""游戏中的图标名"""
|
||||
@abstractmethod
|
||||
@property
|
||||
def honey_name_map(self) -> dict[str, str]:
|
||||
"""来自honey的图标名"""
|
||||
|
||||
class _AvatarAssets(_AssetsService):
|
||||
enka: EnkaCharacterAsset | None
|
||||
|
||||
side: ICON_TYPE
|
||||
"""侧视图图标"""
|
||||
|
||||
card: ICON_TYPE
|
||||
"""卡片图标"""
|
||||
|
||||
gacha: ICON_TYPE
|
||||
"""抽卡立绘"""
|
||||
|
||||
gacha_card: ICON_TYPE
|
||||
"""抽卡卡片"""
|
||||
|
||||
@property
|
||||
def honey_name_map(self) -> dict[str, str]: ...
|
||||
@property
|
||||
def game_name_map(self) -> dict[str, str]: ...
|
||||
@property
|
||||
def enka(self) -> EnkaCharacterAsset | None: ...
|
||||
def __init__(self, client: AsyncClient | None = None, enka: EnkaAssets | None = None) -> None: ...
|
||||
def __call__(self, target: StrOrInt) -> Self: ...
|
||||
def __getitem__(self, item: str) -> _GET_TYPE | EnkaCharacterAsset: ...
|
||||
def game_name(self) -> str: ...
|
||||
|
||||
class _WeaponAssets(_AssetsService):
|
||||
awaken: ICON_TYPE
|
||||
"""突破后图标"""
|
||||
|
||||
gacha: ICON_TYPE
|
||||
"""抽卡立绘"""
|
||||
|
||||
@property
|
||||
def honey_name_map(self) -> dict[str, str]: ...
|
||||
@property
|
||||
def game_name_map(self) -> dict[str, str]: ...
|
||||
def __call__(self, target: StrOrInt) -> Self: ...
|
||||
def game_name(self) -> str: ...
|
||||
|
||||
class _MaterialAssets(_AssetsService):
|
||||
@property
|
||||
def honey_name_map(self) -> dict[str, str]: ...
|
||||
@property
|
||||
def game_name_map(self) -> dict[str, str]: ...
|
||||
def __call__(self, target: StrOrInt) -> Self: ...
|
||||
def game_name(self) -> str: ...
|
||||
|
||||
class _ArtifactAssets(_AssetsService):
|
||||
flower: ICON_TYPE
|
||||
"""生之花"""
|
||||
|
||||
plume: ICON_TYPE
|
||||
"""死之羽"""
|
||||
|
||||
sands: ICON_TYPE
|
||||
"""时之沙"""
|
||||
|
||||
goblet: ICON_TYPE
|
||||
"""空之杯"""
|
||||
|
||||
circlet: ICON_TYPE
|
||||
"""理之冠"""
|
||||
|
||||
@property
|
||||
def honey_name_map(self) -> dict[str, str]: ...
|
||||
@property
|
||||
def game_name_map(self) -> dict[str, str]: ...
|
||||
def game_name(self) -> str: ...
|
||||
|
||||
class _NamecardAssets(_AssetsService):
|
||||
enka: EnkaCharacterAsset | None
|
||||
|
||||
navbar: ICON_TYPE
|
||||
"""好友名片背景"""
|
||||
|
||||
profile: ICON_TYPE
|
||||
"""个人资料名片背景"""
|
||||
|
||||
@property
|
||||
def honey_name_map(self) -> dict[str, str]: ...
|
||||
@property
|
||||
def game_name_map(self) -> dict[str, str]: ...
|
||||
def game_name(self) -> str: ...
|
||||
|
||||
class AssetsService(BaseService.Dependence):
|
||||
avatar: _AvatarAssets
|
||||
"""角色"""
|
||||
|
||||
weapon: _WeaponAssets
|
||||
"""武器"""
|
||||
|
||||
material: _MaterialAssets
|
||||
"""素材"""
|
||||
|
||||
artifact: _ArtifactAssets
|
||||
"""圣遗物"""
|
||||
|
||||
namecard: _NamecardAssets
|
||||
"""名片"""
|
||||
|
||||
AssetsServiceType = TypeVar("AssetsServiceType", bound=_AssetsService)
|
@ -4,6 +4,8 @@ from urllib.parse import urlparse
|
||||
|
||||
import aiofiles
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.config import config as bot_config
|
||||
from utils.log import logger
|
||||
|
||||
try:
|
||||
@ -13,13 +15,12 @@ try:
|
||||
session.log.debug = lambda *args, **kwargs: None # 关闭日记
|
||||
PYROGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
Client = None
|
||||
session = None
|
||||
PYROGRAM_AVAILABLE = False
|
||||
|
||||
from core.bot import bot
|
||||
from core.service import Service
|
||||
|
||||
|
||||
class MTProto(Service):
|
||||
class MTProto(BaseService.Dependence):
|
||||
async def get_session(self):
|
||||
async with aiofiles.open(self.session_path, mode="r") as f:
|
||||
return await f.read()
|
||||
@ -32,9 +33,9 @@ class MTProto(Service):
|
||||
return os.path.exists(self.session_path)
|
||||
|
||||
def __init__(self):
|
||||
self.name = "PaimonBot"
|
||||
self.name = "paigram"
|
||||
current_dir = os.getcwd()
|
||||
self.session_path = os.path.join(current_dir, "paimon.session")
|
||||
self.session_path = os.path.join(current_dir, "paigram.session")
|
||||
self.client: Optional[Client] = None
|
||||
self.proxy: Optional[dict] = None
|
||||
http_proxy = os.environ.get("HTTP_PROXY")
|
||||
@ -42,25 +43,25 @@ class MTProto(Service):
|
||||
http_proxy_url = urlparse(http_proxy)
|
||||
self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port}
|
||||
|
||||
async def start(self): # pylint: disable=W0221
|
||||
async def initialize(self): # pylint: disable=W0221
|
||||
if not PYROGRAM_AVAILABLE:
|
||||
logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None")
|
||||
return
|
||||
if bot.config.mtproto.api_id is None:
|
||||
if bot_config.mtproto.api_id is None:
|
||||
logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None")
|
||||
return
|
||||
if bot.config.mtproto.api_hash is None:
|
||||
if bot_config.mtproto.api_hash is None:
|
||||
logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None")
|
||||
return
|
||||
self.client = Client(
|
||||
api_id=bot.config.mtproto.api_id,
|
||||
api_hash=bot.config.mtproto.api_hash,
|
||||
api_id=bot_config.mtproto.api_id,
|
||||
api_hash=bot_config.mtproto.api_hash,
|
||||
name=self.name,
|
||||
bot_token=bot.config.bot_token,
|
||||
bot_token=bot_config.bot_token,
|
||||
proxy=self.proxy,
|
||||
)
|
||||
await self.client.start()
|
||||
|
||||
async def stop(self): # pylint: disable=W0221
|
||||
async def shutdown(self): # pylint: disable=W0221
|
||||
if self.client is not None:
|
||||
await self.client.stop(block=False)
|
31
core/dependence/mtproto.pyi
Normal file
31
core/dependence/mtproto.pyi
Normal file
@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypedDict
|
||||
|
||||
from core.base_service import BaseService
|
||||
|
||||
try:
|
||||
from pyrogram import Client
|
||||
from pyrogram.session import session
|
||||
|
||||
PYROGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
Client = None
|
||||
session = None
|
||||
PYROGRAM_AVAILABLE = False
|
||||
|
||||
__all__ = ("MTProto",)
|
||||
|
||||
class _ProxyType(TypedDict):
|
||||
scheme: str
|
||||
hostname: str | None
|
||||
port: int | None
|
||||
|
||||
class MTProto(BaseService.Dependence):
|
||||
name: str
|
||||
session_path: str
|
||||
client: Client | None
|
||||
proxy: _ProxyType | None
|
||||
|
||||
async def get_session(self) -> str: ...
|
||||
async def set_session(self, b: str) -> None: ...
|
||||
def session_exists(self) -> bool: ...
|
50
core/dependence/mysql.py
Normal file
50
core/dependence/mysql.py
Normal file
@ -0,0 +1,50 @@
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing_extensions import Self
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.config import ApplicationConfig
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
|
||||
__all__ = ("MySQL",)
|
||||
|
||||
|
||||
class MySQL(BaseService.Dependence):
|
||||
@classmethod
|
||||
def from_config(cls, config: ApplicationConfig) -> Self:
|
||||
return cls(**config.mysql.dict())
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
database: Optional[str] = None,
|
||||
):
|
||||
self.database = database
|
||||
self.password = password
|
||||
self.username = username
|
||||
self.port = port
|
||||
self.host = host
|
||||
self.url = URL.create(
|
||||
"mysql+asyncmy",
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
database=self.database,
|
||||
)
|
||||
self.engine = create_async_engine(self.url)
|
||||
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def session(self) -> AsyncSession:
|
||||
yield self.Session()
|
||||
|
||||
async def shutdown(self):
|
||||
self.Session.close_all()
|
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
|
||||
import fakeredis.aioredis
|
||||
@ -6,14 +5,16 @@ from redis import asyncio as aioredis
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError
|
||||
from typing_extensions import Self
|
||||
|
||||
from core.config import BotConfig
|
||||
from core.service import Service
|
||||
from core.base_service import BaseService
|
||||
from core.config import ApplicationConfig
|
||||
from utils.log import logger
|
||||
|
||||
__all__ = ["RedisDB"]
|
||||
|
||||
class RedisDB(Service):
|
||||
|
||||
class RedisDB(BaseService.Dependence):
|
||||
@classmethod
|
||||
def from_config(cls, config: BotConfig) -> Self:
|
||||
def from_config(cls, config: ApplicationConfig) -> Self:
|
||||
return cls(**config.redis.dict())
|
||||
|
||||
def __init__(
|
||||
@ -24,6 +25,7 @@ class RedisDB(Service):
|
||||
self.key_prefix = "paimon_bot"
|
||||
|
||||
async def ping(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
if await self.client.ping():
|
||||
logger.info("连接 [red]Redis[/] 成功", extra={"markup": True})
|
||||
else:
|
||||
@ -34,7 +36,7 @@ class RedisDB(Service):
|
||||
self.client = fakeredis.aioredis.FakeRedis()
|
||||
await self.ping()
|
||||
|
||||
async def start(self): # pylint: disable=W0221
|
||||
async def initialize(self):
|
||||
logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True})
|
||||
try:
|
||||
await self.ping()
|
||||
@ -45,5 +47,5 @@ class RedisDB(Service):
|
||||
logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
|
||||
await self.start_fake_redis()
|
||||
|
||||
async def stop(self): # pylint: disable=W0221
|
||||
async def shutdown(self):
|
||||
await self.client.close()
|
@ -1,16 +0,0 @@
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.service import init_service
|
||||
from .cache import GameCache
|
||||
from .services import GameMaterialService, GameStrategyService
|
||||
|
||||
|
||||
@init_service
|
||||
def create_game_strategy_service(redis: RedisDB):
|
||||
_cache = GameCache(redis, "game:strategy")
|
||||
return GameStrategyService(_cache)
|
||||
|
||||
|
||||
@init_service
|
||||
def create_game_material_service(redis: RedisDB):
|
||||
_cache = GameCache(redis, "game:material")
|
||||
return GameMaterialService(_cache)
|
59
core/handler/adminhandler.py
Normal file
59
core/handler/adminhandler.py
Normal file
@ -0,0 +1,59 @@
|
||||
import asyncio
|
||||
from typing import TypeVar, TYPE_CHECKING, Any, Optional
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ApplicationHandlerStop, BaseHandler
|
||||
|
||||
from core.error import ServiceNotFoundError
|
||||
from core.services.users.services import UserAdminService
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from telegram.ext import Application as TelegramApplication
|
||||
|
||||
RT = TypeVar("RT")
|
||||
UT = TypeVar("UT")
|
||||
|
||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||
|
||||
|
||||
class AdminHandler(BaseHandler[Update, CCT]):
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None:
|
||||
self.handler = handler
|
||||
self.application = application
|
||||
self.user_service: Optional["UserAdminService"] = None
|
||||
super().__init__(self.handler.callback)
|
||||
|
||||
def check_update(self, update: object) -> bool:
|
||||
if not isinstance(update, Update):
|
||||
return False
|
||||
return self.handler.check_update(update)
|
||||
|
||||
async def _user_service(self) -> "UserAdminService":
|
||||
async with self._lock:
|
||||
if self.user_service is not None:
|
||||
return self.user_service
|
||||
user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None)
|
||||
if user_service is None:
|
||||
raise ServiceNotFoundError("UserAdminService")
|
||||
self.user_service = user_service
|
||||
return self.user_service
|
||||
|
||||
async def handle_update(
|
||||
self,
|
||||
update: "UT",
|
||||
application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]",
|
||||
check_result: Any,
|
||||
context: "CCT",
|
||||
) -> RT:
|
||||
user_service = await self._user_service()
|
||||
user = update.effective_user
|
||||
if await user_service.is_admin(user.id):
|
||||
return await self.handler.handle_update(update, application, check_result, context)
|
||||
message = update.effective_message
|
||||
logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id)
|
||||
await message.reply_text("权限不足")
|
||||
raise ApplicationHandlerStop
|
62
core/handler/callbackqueryhandler.py
Normal file
62
core/handler/callbackqueryhandler.py
Normal file
@ -0,0 +1,62 @@
|
||||
import asyncio
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from types import TracebackType
|
||||
from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type
|
||||
|
||||
from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram.ext import Application
|
||||
|
||||
RT = TypeVar("RT")
|
||||
UT = TypeVar("UT")
|
||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||
|
||||
|
||||
class OverlappingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OverlappingContext(AbstractAsyncContextManager):
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, context: "CCT"):
|
||||
self.context = context
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
async with self._lock:
|
||||
flag = self.context.user_data.get("overlapping", False)
|
||||
if flag:
|
||||
raise OverlappingException
|
||||
self.context.user_data["overlapping"] = True
|
||||
return None
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
async with self._lock:
|
||||
del self.context.user_data["overlapping"]
|
||||
return None
|
||||
|
||||
|
||||
class CallbackQueryHandler(BaseCallbackQueryHandler):
|
||||
async def handle_update(
|
||||
self,
|
||||
update: "UT",
|
||||
application: "Application[Any, CCT, Any, Any, Any, Any]",
|
||||
check_result: Any,
|
||||
context: "CCT",
|
||||
) -> RT:
|
||||
self.collect_additional_context(context, update, application, check_result)
|
||||
try:
|
||||
async with OverlappingContext(context):
|
||||
return await self.callback(update, context)
|
||||
except OverlappingException as exc:
|
||||
user = update.effective_user
|
||||
logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id)
|
||||
raise ApplicationHandlerStop from exc
|
71
core/handler/limiterhandler.py
Normal file
71
core/handler/limiterhandler.py
Normal file
@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
from typing import TypeVar, Optional
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
UT = TypeVar("UT")
|
||||
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||
|
||||
|
||||
class LimiterHandler(TypeHandler[UT, CCT]):
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(
|
||||
self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None
|
||||
):
|
||||
"""Limiter Handler 通过
|
||||
`Leaky bucket algorithm <https://en.wikipedia.org/wiki/Leaky_bucket>`_
|
||||
实现对用户的输入的精确控制
|
||||
|
||||
输入超过一定速率后,代码会抛出
|
||||
:class:`telegram.ext.ApplicationHandlerStop`
|
||||
异常并在一段时间内防止用户执行任何其他操作
|
||||
|
||||
:param max_rate: 在抛出异常之前最多允许 频率/秒 的速度
|
||||
:param time_period: 在限制速率的时间段的持续时间
|
||||
:param amount: 提供的容量
|
||||
:param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount
|
||||
"""
|
||||
self.max_rate = max_rate
|
||||
self.amount = amount
|
||||
self._rate_per_sec = max_rate / time_period
|
||||
self.limit_time = limit_time
|
||||
super().__init__(Update, self.limiter_callback)
|
||||
|
||||
async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if update.inline_query is not None:
|
||||
return
|
||||
loop = asyncio.get_running_loop()
|
||||
async with self._lock:
|
||||
time = loop.time()
|
||||
user_data = context.user_data
|
||||
if user_data is None:
|
||||
return
|
||||
user_limit_time = user_data.get("limit_time")
|
||||
if user_limit_time is not None:
|
||||
if time >= user_limit_time:
|
||||
del user_data["limit_time"]
|
||||
else:
|
||||
raise ApplicationHandlerStop
|
||||
last_task_time = user_data.get("last_task_time", 0)
|
||||
if last_task_time:
|
||||
task_level = user_data.get("task_level", 0)
|
||||
elapsed = time - last_task_time
|
||||
decrement = elapsed * self._rate_per_sec
|
||||
task_level = max(task_level - decrement, 0)
|
||||
user_data["task_level"] = task_level
|
||||
if not task_level + self.amount <= self.max_rate:
|
||||
if self.limit_time:
|
||||
limit_time = self.limit_time
|
||||
else:
|
||||
limit_time = 1 / self._rate_per_sec * self.amount
|
||||
user_data["limit_time"] = time + limit_time
|
||||
user = update.effective_user
|
||||
logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s 秒", user.full_name, user.id, limit_time)
|
||||
raise ApplicationHandlerStop
|
||||
user_data["last_task_time"] = time
|
||||
task_level = user_data.get("task_level", 0)
|
||||
user_data["task_level"] = task_level + self.amount
|
286
core/manager.py
Normal file
286
core/manager.py
Normal file
@ -0,0 +1,286 @@
|
||||
import asyncio
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
from arkowrapper import ArkoWrapper
|
||||
from async_timeout import timeout
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
|
||||
from core.config import config as bot_config
|
||||
from utils.const import PLUGIN_DIR, PROJECT_ROOT
|
||||
from utils.helpers import gen_pkg
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from core.plugin import PluginType
|
||||
from core.builtins.executor import Executor
|
||||
|
||||
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
|
||||
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _load_module(path: Path) -> None:
|
||||
for pkg in gen_pkg(path):
|
||||
try:
|
||||
logger.debug('正在导入 "%s"', pkg)
|
||||
import_module(pkg)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
||||
)
|
||||
raise SystemExit from e
|
||||
|
||||
|
||||
class Manager(Generic[T]):
|
||||
"""生命周期控制基类"""
|
||||
|
||||
_executor: Optional["Executor"] = None
|
||||
_lib: Dict[Type[T], T] = {}
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||
return self._application
|
||||
|
||||
@property
|
||||
def executor(self) -> "Executor":
|
||||
"""执行器"""
|
||||
if self._executor is None:
|
||||
raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.")
|
||||
return self._executor
|
||||
|
||||
def build_executor(self, name: str):
|
||||
from core.builtins.executor import Executor
|
||||
from core.builtins.dispatcher import BaseDispatcher
|
||||
|
||||
self._executor = Executor(name, dispatcher=BaseDispatcher)
|
||||
self._executor.set_application(self.application)
|
||||
|
||||
|
||||
class DependenceManager(Manager[DependenceType]):
|
||||
"""基础依赖管理"""
|
||||
|
||||
_dependency: Dict[Type[DependenceType], DependenceType] = {}
|
||||
|
||||
@property
|
||||
def dependency(self) -> List[DependenceType]:
|
||||
return list(self._dependency.values())
|
||||
|
||||
@property
|
||||
def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]:
|
||||
return self._dependency
|
||||
|
||||
async def start_dependency(self) -> None:
|
||||
_load_module(PROJECT_ROOT / "core/dependence")
|
||||
|
||||
for dependence in filter(lambda x: x.is_dependence, get_all_services()):
|
||||
dependence: Type[DependenceType]
|
||||
instance: DependenceType
|
||||
try:
|
||||
if hasattr(dependence, "from_config"): # 如果有 from_config 方法
|
||||
instance = dependence.from_config(bot_config) # 用 from_config 实例化服务
|
||||
else:
|
||||
instance = await self.executor(dependence)
|
||||
|
||||
await instance.initialize()
|
||||
logger.success('基础服务 "%s" 启动成功', dependence.__name__)
|
||||
|
||||
self._lib[dependence] = instance
|
||||
self._dependency[dependence] = instance
|
||||
|
||||
except Exception as e:
|
||||
logger.exception('基础服务 "%s" 初始化失败,BOT 将自动关闭', dependence.__name__)
|
||||
raise SystemExit from e
|
||||
|
||||
async def stop_dependency(self) -> None:
|
||||
async def task(d):
|
||||
try:
|
||||
async with timeout(5):
|
||||
await d.shutdown()
|
||||
logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e)
|
||||
|
||||
tasks = []
|
||||
for dependence in self._dependency.values():
|
||||
tasks.append(asyncio.create_task(task(dependence)))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
class ComponentManager(Manager[ComponentType]):
|
||||
"""组件管理"""
|
||||
|
||||
_components: Dict[Type[ComponentType], ComponentType] = {}
|
||||
|
||||
@property
|
||||
def components(self) -> List[ComponentType]:
|
||||
return list(self._components.values())
|
||||
|
||||
@property
|
||||
def components_map(self) -> Dict[Type[ComponentType], ComponentType]:
|
||||
return self._components
|
||||
|
||||
async def init_components(self):
|
||||
for path in filter(
|
||||
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
||||
):
|
||||
_load_module(path)
|
||||
components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component)
|
||||
retry_times = 0
|
||||
max_retry_times = len(components)
|
||||
while components:
|
||||
start_len = len(components)
|
||||
for component in list(components):
|
||||
component: Type[ComponentType]
|
||||
instance: ComponentType
|
||||
try:
|
||||
instance = await self.executor(component)
|
||||
self._lib[component] = instance
|
||||
self._components[component] = instance
|
||||
components = components.remove(component)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True})
|
||||
end_len = len(list(components))
|
||||
if start_len == end_len:
|
||||
retry_times += 1
|
||||
|
||||
if retry_times == max_retry_times and components:
|
||||
for component in components:
|
||||
logger.error('组件 "%s" 初始化失败', component.__name__)
|
||||
raise SystemExit
|
||||
|
||||
|
||||
class ServiceManager(Manager[BaseServiceType]):
|
||||
"""服务控制类"""
|
||||
|
||||
_services: Dict[Type[BaseServiceType], BaseServiceType] = {}
|
||||
|
||||
@property
|
||||
def services(self) -> List[BaseServiceType]:
|
||||
return list(self._services.values())
|
||||
|
||||
@property
|
||||
def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]:
|
||||
return self._services
|
||||
|
||||
async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType:
|
||||
instance: BaseServiceType
|
||||
try:
|
||||
if hasattr(target, "from_config"): # 如果有 from_config 方法
|
||||
instance = target.from_config(bot_config) # 用 from_config 实例化服务
|
||||
else:
|
||||
instance = await self.executor(target)
|
||||
|
||||
await instance.initialize()
|
||||
logger.success('服务 "%s" 启动成功', target.__name__)
|
||||
|
||||
return instance
|
||||
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception('服务 "%s" 初始化失败,BOT 将自动关闭', target.__name__)
|
||||
raise SystemExit from e
|
||||
|
||||
async def start_services(self) -> None:
|
||||
for path in filter(
|
||||
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
||||
):
|
||||
_load_module(path)
|
||||
|
||||
for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类
|
||||
instance = await self._initialize_service(service)
|
||||
|
||||
self._lib[service] = instance
|
||||
self._services[service] = instance
|
||||
|
||||
async def stop_services(self) -> None:
|
||||
"""关闭服务"""
|
||||
if not self._services:
|
||||
return
|
||||
|
||||
async def task(s):
|
||||
try:
|
||||
async with timeout(5):
|
||||
await s.shutdown()
|
||||
logger.success('服务 "%s" 关闭成功', s.__class__.__name__)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('服务 "%s" 关闭超时', s.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e)
|
||||
|
||||
logger.info("正在关闭服务")
|
||||
tasks = []
|
||||
for service in self._services.values():
|
||||
tasks.append(asyncio.create_task(task(service)))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
class PluginManager(Manager["PluginType"]):
|
||||
"""插件管理"""
|
||||
|
||||
_plugins: Dict[Type["PluginType"], "PluginType"] = {}
|
||||
|
||||
@property
|
||||
def plugins(self) -> List["PluginType"]:
|
||||
"""所有已经加载的插件"""
|
||||
return list(self._plugins.values())
|
||||
|
||||
@property
|
||||
def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]:
|
||||
return self._plugins
|
||||
|
||||
async def install_plugins(self) -> None:
|
||||
"""安装所有插件"""
|
||||
from core.plugin import get_all_plugins
|
||||
|
||||
for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()):
|
||||
_load_module(path)
|
||||
|
||||
for plugin in get_all_plugins():
|
||||
plugin: Type["PluginType"]
|
||||
|
||||
try:
|
||||
instance: "PluginType" = await self.executor(plugin)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||
continue
|
||||
|
||||
self._plugins[plugin] = instance
|
||||
|
||||
if self._application is not None:
|
||||
instance.set_application(self._application)
|
||||
|
||||
await asyncio.create_task(self.plugin_install_task(plugin, instance))
|
||||
|
||||
@staticmethod
|
||||
async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"):
|
||||
try:
|
||||
await instance.install()
|
||||
logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}")
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||
|
||||
async def uninstall_plugins(self) -> None:
|
||||
for plugin in self._plugins.values():
|
||||
try:
|
||||
await plugin.uninstall()
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||
|
||||
|
||||
class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager):
|
||||
"""BOT 除自身外的生命周期管理类"""
|
106
core/override/telegram.py
Normal file
106
core/override/telegram.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化"""
|
||||
from typing import Any, AsyncIterable, Optional
|
||||
|
||||
import httpcore
|
||||
from httpx import (
|
||||
AsyncByteStream,
|
||||
AsyncHTTPTransport as DefaultAsyncHTTPTransport,
|
||||
Limits,
|
||||
Response as DefaultResponse,
|
||||
Timeout,
|
||||
)
|
||||
from telegram.request import HTTPXRequest as DefaultHTTPXRequest
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
|
||||
__all__ = ("HTTPXRequest",)
|
||||
|
||||
|
||||
class Response(DefaultResponse):
|
||||
def json(self, **kwargs: Any) -> Any:
|
||||
# noinspection PyProtectedMember
|
||||
from httpx._utils import guess_json_utf
|
||||
|
||||
if self.charset_encoding is None and self.content and len(self.content) > 3:
|
||||
encoding = guess_json_utf(self.content)
|
||||
if encoding is not None:
|
||||
return jsonlib.loads(self.content.decode(encoding), **kwargs)
|
||||
return jsonlib.loads(self.text, **kwargs)
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
class AsyncHTTPTransport(DefaultAsyncHTTPTransport):
|
||||
async def handle_async_request(self, request) -> Response:
|
||||
from httpx._transports.default import (
|
||||
map_httpcore_exceptions,
|
||||
AsyncResponseStream,
|
||||
)
|
||||
|
||||
if not isinstance(request.stream, AsyncByteStream):
|
||||
raise AssertionError
|
||||
|
||||
req = httpcore.Request(
|
||||
method=request.method,
|
||||
url=httpcore.URL(
|
||||
scheme=request.url.raw_scheme,
|
||||
host=request.url.raw_host,
|
||||
port=request.url.port,
|
||||
target=request.url.raw_path,
|
||||
),
|
||||
headers=request.headers.raw,
|
||||
content=request.stream,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
with map_httpcore_exceptions():
|
||||
resp = await self._pool.handle_async_request(req)
|
||||
|
||||
if not isinstance(resp.stream, AsyncIterable):
|
||||
raise AssertionError
|
||||
|
||||
return Response(
|
||||
status_code=resp.status,
|
||||
headers=resp.headers,
|
||||
stream=AsyncResponseStream(resp.stream),
|
||||
extensions=resp.extensions,
|
||||
)
|
||||
|
||||
|
||||
class HTTPXRequest(DefaultHTTPXRequest):
|
||||
def __init__( # pylint: disable=W0231
|
||||
self,
|
||||
connection_pool_size: int = 1,
|
||||
proxy_url: str = None,
|
||||
read_timeout: Optional[float] = 5.0,
|
||||
write_timeout: Optional[float] = 5.0,
|
||||
connect_timeout: Optional[float] = 5.0,
|
||||
pool_timeout: Optional[float] = 1.0,
|
||||
):
|
||||
timeout = Timeout(
|
||||
connect=connect_timeout,
|
||||
read=read_timeout,
|
||||
write=write_timeout,
|
||||
pool=pool_timeout,
|
||||
)
|
||||
limits = Limits(
|
||||
max_connections=connection_pool_size,
|
||||
max_keepalive_connections=connection_pool_size,
|
||||
)
|
||||
self._client_kwargs = dict(
|
||||
timeout=timeout,
|
||||
proxies=proxy_url,
|
||||
limits=limits,
|
||||
transport=AsyncHTTPTransport(limits=limits),
|
||||
)
|
||||
|
||||
try:
|
||||
self._client = self._build_client()
|
||||
except ImportError as exc:
|
||||
if "httpx[socks]" not in str(exc):
|
||||
raise exc
|
||||
|
||||
raise RuntimeError(
|
||||
"To use Socks5 proxies, PTB must be installed via `pip install python-telegram-bot[socks]`."
|
||||
) from exc
|
483
core/plugin.py
483
core/plugin.py
@ -1,483 +0,0 @@
|
||||
import copy
|
||||
import datetime
|
||||
import re
|
||||
from importlib import import_module
|
||||
from re import Pattern
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.defaultvalue import DEFAULT_TRUE
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.types import DVInput, JSONDict
|
||||
from telegram.ext import BaseHandler, ConversationHandler, Job
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram.ext._utils.types import JobCallback
|
||||
from telegram.ext.filters import BaseFilter
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
__all__ = ["Plugin", "handler", "conversation", "job", "error_handler"]
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
||||
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
|
||||
|
||||
_Module = import_module("telegram.ext")
|
||||
|
||||
_NORMAL_HANDLER_ATTR_NAME = "_handler_data"
|
||||
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_data"
|
||||
_JOB_ATTR_NAME = "_job_data"
|
||||
|
||||
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
|
||||
|
||||
|
||||
class _Plugin:
|
||||
def _make_handler(self, datas: Union[List[Dict], Dict]) -> List[HandlerType]:
|
||||
result = []
|
||||
if isinstance(datas, list):
|
||||
for data in filter(lambda x: x, datas):
|
||||
func = getattr(self, data.pop("func"))
|
||||
result.append(data.pop("type")(callback=func, **data.pop("kwargs")))
|
||||
else:
|
||||
func = getattr(self, datas.pop("func"))
|
||||
result.append(datas.pop("type")(callback=func, **datas.pop("kwargs")))
|
||||
return result
|
||||
|
||||
@property
|
||||
def handlers(self) -> List[HandlerType]:
|
||||
result = []
|
||||
for attr in dir(self):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
||||
):
|
||||
for data in datas:
|
||||
if data["type"] not in ["error", "new_chat_member"]:
|
||||
result.extend(self._make_handler(data))
|
||||
return result
|
||||
|
||||
def _new_chat_members_handler_funcs(self) -> List[Tuple[int, Callable]]:
|
||||
result = []
|
||||
for attr in dir(self):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
||||
):
|
||||
for data in datas:
|
||||
if data and data["type"] == "new_chat_member":
|
||||
result.append((data["priority"], func))
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def error_handlers(self) -> Dict[Callable, bool]:
|
||||
result = {}
|
||||
for attr in dir(self):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
||||
):
|
||||
for data in datas:
|
||||
if data and data["type"] == "error":
|
||||
result.update({func: data["block"]})
|
||||
return result
|
||||
|
||||
@property
|
||||
def jobs(self) -> List[Job]:
|
||||
from core.bot import bot
|
||||
|
||||
result = []
|
||||
for attr in dir(self):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _JOB_ATTR_NAME, None))
|
||||
):
|
||||
for data in datas:
|
||||
_job = getattr(bot.job_queue, data.pop("type"))(
|
||||
callback=func, **data.pop("kwargs"), **{key: data.pop(key) for key in list(data.keys())}
|
||||
)
|
||||
result.append(_job)
|
||||
return result
|
||||
|
||||
|
||||
class _Conversation(_Plugin):
|
||||
_conversation_kwargs: Dict
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
cls._conversation_kwargs = kwargs
|
||||
super(_Conversation, cls).__init_subclass__()
|
||||
return cls
|
||||
|
||||
@property
|
||||
def handlers(self) -> List[HandlerType]:
|
||||
result: List[HandlerType] = []
|
||||
|
||||
entry_points: List[HandlerType] = []
|
||||
states: Dict[Any, List[HandlerType]] = {}
|
||||
fallbacks: List[HandlerType] = []
|
||||
for attr in dir(self):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if (
|
||||
not (attr.startswith("_") or attr == "handlers")
|
||||
and isinstance(func := getattr(self, attr), Callable)
|
||||
and (handler_datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
||||
):
|
||||
conversation_data = getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None)
|
||||
if attr == "cancel":
|
||||
handler_datas = copy.deepcopy(handler_datas)
|
||||
conversation_data = copy.deepcopy(conversation_data)
|
||||
_handlers = self._make_handler(handler_datas)
|
||||
if conversation_data:
|
||||
if (_type := conversation_data.pop("type")) == "entry":
|
||||
entry_points.extend(_handlers)
|
||||
elif _type == "state":
|
||||
if (key := conversation_data.pop("state")) in states:
|
||||
states[key].extend(_handlers)
|
||||
else:
|
||||
states[key] = _handlers
|
||||
elif _type == "fallback":
|
||||
fallbacks.extend(_handlers)
|
||||
else:
|
||||
result.extend(_handlers)
|
||||
if entry_points or states or fallbacks:
|
||||
result.append(
|
||||
ConversationHandler(
|
||||
entry_points, states, fallbacks, **self.__class__._conversation_kwargs # pylint: disable=W0212
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class Plugin(_Plugin):
|
||||
Conversation = _Conversation
|
||||
|
||||
|
||||
class _Handler:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
def _type(self) -> Type[BaseHandler]:
|
||||
return getattr(_Module, f"{self.__class__.__name__.strip('_')}Handler")
|
||||
|
||||
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
|
||||
data = {"type": self._type, "func": func.__name__, "kwargs": self.kwargs}
|
||||
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
|
||||
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
|
||||
handler_datas.append(data)
|
||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
|
||||
else:
|
||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
|
||||
return func
|
||||
|
||||
|
||||
class _CallbackQuery(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
):
|
||||
super(_CallbackQuery, self).__init__(pattern=pattern, block=block)
|
||||
|
||||
|
||||
class _ChatJoinRequest(_Handler):
|
||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super(_ChatJoinRequest, self).__init__(block=block)
|
||||
|
||||
|
||||
class _ChatMember(_Handler):
|
||||
def __init__(self, chat_member_types: int = -1, block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super().__init__(chat_member_types=chat_member_types, block=block)
|
||||
|
||||
|
||||
class _ChosenInlineResult(_Handler):
|
||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE, pattern: Union[str, Pattern] = None):
|
||||
super().__init__(block=block, pattern=pattern)
|
||||
|
||||
|
||||
class _Command(_Handler):
|
||||
def __init__(self, command: str, filters: "BaseFilter" = None, block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super(_Command, self).__init__(command=command, filters=filters, block=block)
|
||||
|
||||
|
||||
class _InlineQuery(_Handler):
|
||||
def __init__(
|
||||
self, pattern: Union[str, Pattern] = None, block: DVInput[bool] = DEFAULT_TRUE, chat_types: List[str] = None
|
||||
):
|
||||
super().__init__(pattern=pattern, block=block, chat_types=chat_types)
|
||||
|
||||
|
||||
class _MessageNewChatMembers(_Handler):
|
||||
def __init__(self, func: Callable[P, T] = None, *, priority: int = 5):
|
||||
super().__init__()
|
||||
self.func = func
|
||||
self.priority = priority
|
||||
|
||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
||||
self.func = self.func or func
|
||||
data = {"type": "new_chat_member", "priority": self.priority}
|
||||
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
|
||||
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
|
||||
handler_datas.append(data)
|
||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
|
||||
else:
|
||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
|
||||
return func
|
||||
|
||||
|
||||
class _Message(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
filters: "BaseFilter",
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
):
|
||||
super(_Message, self).__init__(filters=filters, block=block)
|
||||
|
||||
new_chat_members = _MessageNewChatMembers
|
||||
|
||||
|
||||
class _PollAnswer(_Handler):
|
||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super(_PollAnswer, self).__init__(block=block)
|
||||
|
||||
|
||||
class _Poll(_Handler):
|
||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super(_Poll, self).__init__(block=block)
|
||||
|
||||
|
||||
class _PreCheckoutQuery(_Handler):
|
||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super(_PreCheckoutQuery, self).__init__(block=block)
|
||||
|
||||
|
||||
class _Prefix(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
command: str,
|
||||
filters: BaseFilter = None,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
):
|
||||
super(_Prefix, self).__init__(prefix=prefix, command=command, filters=filters, block=block)
|
||||
|
||||
|
||||
class _ShippingQuery(_Handler):
|
||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super(_ShippingQuery, self).__init__(block=block)
|
||||
|
||||
|
||||
class _StringCommand(_Handler):
|
||||
def __init__(self, command: str):
|
||||
super(_StringCommand, self).__init__(command=command)
|
||||
|
||||
|
||||
class _StringRegex(_Handler):
|
||||
def __init__(self, pattern: Union[str, Pattern], block: DVInput[bool] = DEFAULT_TRUE):
|
||||
super(_StringRegex, self).__init__(pattern=pattern, block=block)
|
||||
|
||||
|
||||
class _Type(_Handler):
|
||||
# noinspection PyShadowingBuiltins
|
||||
def __init__(
|
||||
self, type: Type, strict: bool = False, block: DVInput[bool] = DEFAULT_TRUE # pylint: disable=redefined-builtin
|
||||
):
|
||||
super(_Type, self).__init__(type=type, strict=strict, block=block)
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class handler(_Handler):
|
||||
def __init__(self, handler_type: Callable[P, HandlerType], **kwargs: P.kwargs):
|
||||
self._type_ = handler_type
|
||||
super(handler, self).__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _type(self) -> Type[BaseHandler]:
|
||||
# noinspection PyTypeChecker
|
||||
return self._type_
|
||||
|
||||
callback_query = _CallbackQuery
|
||||
chat_join_request = _ChatJoinRequest
|
||||
chat_member = _ChatMember
|
||||
chosen_inline_result = _ChosenInlineResult
|
||||
command = _Command
|
||||
inline_query = _InlineQuery
|
||||
message = _Message
|
||||
poll_answer = _PollAnswer
|
||||
pool = _Poll
|
||||
pre_checkout_query = _PreCheckoutQuery
|
||||
prefix = _Prefix
|
||||
shipping_query = _ShippingQuery
|
||||
string_command = _StringCommand
|
||||
string_regex = _StringRegex
|
||||
type = _Type
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class error_handler:
|
||||
def __init__(self, func: Callable[P, T] = None, *, block: bool = DEFAULT_TRUE):
|
||||
self._func = func
|
||||
self._block = block
|
||||
|
||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
||||
self._func = func or self._func
|
||||
data = {"type": "error", "block": self._block}
|
||||
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
|
||||
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
|
||||
handler_datas.append(data)
|
||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
|
||||
else:
|
||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
|
||||
return func
|
||||
|
||||
|
||||
def _entry(func: Callable[P, T]) -> Callable[P, T]:
|
||||
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "entry"})
|
||||
return func
|
||||
|
||||
|
||||
class _State:
|
||||
def __init__(self, state: Any):
|
||||
self.state = state
|
||||
|
||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
||||
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "state", "state": self.state})
|
||||
return func
|
||||
|
||||
|
||||
def _fallback(func: Callable[P, T]) -> Callable[P, T]:
|
||||
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "fallback"})
|
||||
return func
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class conversation(_Handler):
|
||||
entry_point = _entry
|
||||
state = _State
|
||||
fallback = _fallback
|
||||
|
||||
|
||||
class _Job:
|
||||
kwargs: Dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = None,
|
||||
data: object = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.name = name
|
||||
self.data = data
|
||||
self.chat_id = chat_id
|
||||
self.user_id = user_id
|
||||
self.job_kwargs = {} if job_kwargs is None else job_kwargs
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, func: JobCallback) -> JobCallback:
|
||||
data = {
|
||||
"name": self.name,
|
||||
"data": self.data,
|
||||
"chat_id": self.chat_id,
|
||||
"user_id": self.user_id,
|
||||
"job_kwargs": self.job_kwargs,
|
||||
"kwargs": self.kwargs,
|
||||
"type": re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
|
||||
}
|
||||
if hasattr(func, _JOB_ATTR_NAME):
|
||||
job_datas = getattr(func, _JOB_ATTR_NAME)
|
||||
job_datas.append(data)
|
||||
setattr(func, _JOB_ATTR_NAME, job_datas)
|
||||
else:
|
||||
setattr(func, _JOB_ATTR_NAME, [data])
|
||||
return func
|
||||
|
||||
|
||||
class _RunOnce(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
when: TimeType,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, when=when)
|
||||
|
||||
|
||||
class _RunRepeating(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
interval: Union[float, datetime.timedelta],
|
||||
first: TimeType = None,
|
||||
last: TimeType = None,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, interval=interval, first=first, last=last)
|
||||
|
||||
|
||||
class _RunMonthly(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
when: datetime.time,
|
||||
day: int,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, when=when, day=day)
|
||||
|
||||
|
||||
class _RunDaily(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
time: datetime.time,
|
||||
days: Tuple[int, ...] = tuple(range(7)),
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, time=time, days=days)
|
||||
|
||||
|
||||
class _RunCustom(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs)
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class job:
|
||||
run_once = _RunOnce
|
||||
run_repeating = _RunRepeating
|
||||
run_monthly = _RunMonthly
|
||||
run_daily = _RunDaily
|
||||
run_custom = _RunCustom
|
16
core/plugin/__init__.py
Normal file
16
core/plugin/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""插件"""
|
||||
|
||||
from core.plugin._handler import conversation, error_handler, handler
|
||||
from core.plugin._job import TimeType, job
|
||||
from core.plugin._plugin import Plugin, PluginType, get_all_plugins
|
||||
|
||||
__all__ = (
|
||||
"Plugin",
|
||||
"PluginType",
|
||||
"get_all_plugins",
|
||||
"handler",
|
||||
"error_handler",
|
||||
"conversation",
|
||||
"job",
|
||||
"TimeType",
|
||||
)
|
175
core/plugin/_funcs.py
Normal file
175
core/plugin/_funcs.py
Normal file
@ -0,0 +1,175 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union, TYPE_CHECKING
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from httpx import UnsupportedProtocol
|
||||
from telegram import Chat, Message, ReplyKeyboardRemove, Update
|
||||
from telegram.error import BadRequest, Forbidden
|
||||
from telegram.ext import CallbackContext, ConversationHandler, Job
|
||||
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from core.plugin._handler import conversation, handler
|
||||
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
||||
from utils.error import UrlResourcesNotFoundError
|
||||
from utils.helpers import sha1
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
__all__ = (
|
||||
"PluginFuncs",
|
||||
"ConversationFuncs",
|
||||
)
|
||||
|
||||
|
||||
class PluginFuncs:
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError("No application was set for this PluginManager.")
|
||||
return self._application
|
||||
|
||||
async def _delete_message(self, context: CallbackContext) -> None:
|
||||
job = context.job
|
||||
message_id = job.data
|
||||
chat_info = f"chat_id[{job.chat_id}]"
|
||||
|
||||
try:
|
||||
chat = await self.get_chat(job.chat_id)
|
||||
full_name = chat.full_name
|
||||
if full_name:
|
||||
chat_info = f"{full_name}[{chat.id}]"
|
||||
else:
|
||||
chat_info = f"{chat.title}[{chat.id}]"
|
||||
except (BadRequest, Forbidden) as exc:
|
||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
||||
except Exception as exc:
|
||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
||||
|
||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
||||
|
||||
try:
|
||||
# noinspection PyTypeChecker
|
||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
||||
except BadRequest as exc:
|
||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||
|
||||
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, ttl: int = 86400) -> Chat:
|
||||
application = self.application
|
||||
redis_db: RedisDB = redis_db or self.application.managers.services_map.get(RedisDB, None)
|
||||
|
||||
if not redis_db:
|
||||
return await application.bot.get_chat(chat_id)
|
||||
|
||||
qname = f"bot:chat:{chat_id}"
|
||||
|
||||
data = await redis_db.client.get(qname)
|
||||
if data:
|
||||
json_data = json.loads(data)
|
||||
return Chat.de_json(json_data, application.telegram.bot)
|
||||
|
||||
chat_info = await application.telegram.bot.get_chat(chat_id)
|
||||
await redis_db.client.set(qname, chat_info.to_json())
|
||||
await redis_db.client.expire(qname, ttl)
|
||||
return chat_info
|
||||
|
||||
def add_delete_message_job(
|
||||
self,
|
||||
message: Optional[Union[int, Message]] = None,
|
||||
*,
|
||||
delay: int = 60,
|
||||
name: Optional[str] = None,
|
||||
chat: Optional[Union[int, Chat]] = None,
|
||||
context: Optional[CallbackContext] = None,
|
||||
) -> Job:
|
||||
"""延迟删除消息"""
|
||||
|
||||
if isinstance(message, Message):
|
||||
if chat is None:
|
||||
chat = message.chat_id
|
||||
message = message.id
|
||||
|
||||
chat = chat.id if isinstance(chat, Chat) else chat
|
||||
|
||||
job_queue = self.application.job_queue or context.job_queue
|
||||
|
||||
if job_queue is None:
|
||||
raise RuntimeError
|
||||
|
||||
return job_queue.run_once(
|
||||
callback=self._delete_message,
|
||||
when=delay,
|
||||
data=message,
|
||||
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
||||
chat_id=chat,
|
||||
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def download_resource(url: str, return_path: bool = False) -> str:
|
||||
url_sha1 = sha1(url) # url 的 hash 值
|
||||
pathed_url = Path(url)
|
||||
|
||||
file_name = url_sha1 + pathed_url.suffix
|
||||
file_path = CACHE_DIR.joinpath(file_name)
|
||||
|
||||
if not file_path.exists(): # 若文件不存在,则下载
|
||||
async with httpx.AsyncClient(headers=REQUEST_HEADERS) as client:
|
||||
try:
|
||||
response = await client.get(url)
|
||||
except UnsupportedProtocol:
|
||||
logger.error("链接不支持 url[%s]", url)
|
||||
return ""
|
||||
|
||||
if response.is_error:
|
||||
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
|
||||
raise UrlResourcesNotFoundError(url)
|
||||
|
||||
async with aiofiles.open(file_path, mode="wb") as f:
|
||||
await f.write(response.content)
|
||||
|
||||
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
||||
|
||||
return file_path if return_path else Path(file_path).as_uri()
|
||||
|
||||
@staticmethod
|
||||
def get_args(context: Optional[CallbackContext] = None) -> List[str]:
|
||||
args = context.args
|
||||
match = context.match
|
||||
|
||||
if args is None:
|
||||
if match is not None and (command := match.groups()[0]):
|
||||
temp = []
|
||||
command_parts = command.split(" ")
|
||||
for command_part in command_parts:
|
||||
if command_part:
|
||||
temp.append(command_part)
|
||||
return temp
|
||||
return []
|
||||
if len(args) >= 1:
|
||||
return args
|
||||
return []
|
||||
|
||||
|
||||
class ConversationFuncs:
|
||||
@conversation.fallback
|
||||
@handler.command(command="cancel", block=True)
|
||||
async def cancel(self, update: Update, _) -> int:
|
||||
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
|
||||
return ConversationHandler.END
|
380
core/plugin/_handler.py
Normal file
380
core/plugin/_handler.py
Normal file
@ -0,0 +1,380 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from importlib import import_module
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.defaultvalue import DEFAULT_TRUE
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.types import DVInput
|
||||
from telegram.ext import BaseHandler
|
||||
from telegram.ext.filters import BaseFilter
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from core.handler.callbackqueryhandler import CallbackQueryHandler
|
||||
from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.builtins.dispatcher import AbstractDispatcher
|
||||
|
||||
__all__ = (
|
||||
"handler",
|
||||
"conversation",
|
||||
"ConversationDataType",
|
||||
"ConversationData",
|
||||
"HandlerData",
|
||||
"ErrorHandlerData",
|
||||
"error_handler",
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
UT = TypeVar("UT")
|
||||
|
||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
||||
HandlerCls = Type[HandlerType]
|
||||
|
||||
Module = import_module("telegram.ext")
|
||||
|
||||
HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
||||
|
||||
CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
||||
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
WRAPPER_ASSIGNMENTS = list(
|
||||
set(
|
||||
_WRAPPER_ASSIGNMENTS
|
||||
+ [
|
||||
HANDLER_DATA_ATTR_NAME,
|
||||
ERROR_HANDLER_ATTR_NAME,
|
||||
CONVERSATION_HANDLER_ATTR_NAME,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class HandlerData:
|
||||
type: Type[HandlerType]
|
||||
admin: bool
|
||||
kwargs: Dict[str, Any]
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
||||
|
||||
|
||||
class _Handler:
|
||||
_type: Type["HandlerType"]
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""用于获取 python-telegram-bot 中对应的 handler class"""
|
||||
|
||||
handler_name = f"{cls.__name__.strip('_')}Handler"
|
||||
|
||||
if handler_name == "CallbackQueryHandler":
|
||||
cls._type = CallbackQueryHandler
|
||||
return
|
||||
|
||||
cls._type = getattr(Module, handler_name, None)
|
||||
|
||||
def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None:
|
||||
self.dispatcher = dispatcher
|
||||
self.admin = admin
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""decorator实现,从 func 生成 Handler"""
|
||||
|
||||
handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, [])
|
||||
handler_datas.append(
|
||||
HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher)
|
||||
)
|
||||
setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class _CallbackQuery(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _ChatJoinRequest(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _ChatMember(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
chat_member_types: int = -1,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _ChosenInlineResult(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
*,
|
||||
pattern: Union[str, Pattern] = None,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(block=block, pattern=pattern, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Command(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
command: Union[str, List[str]],
|
||||
filters: "BaseFilter" = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_Command, self).__init__(
|
||||
command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher
|
||||
)
|
||||
|
||||
|
||||
class _InlineQuery(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[str, Pattern] = None,
|
||||
chat_types: List[str] = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Message(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
filters: BaseFilter,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
) -> None:
|
||||
super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _PollAnswer(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Poll(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_Poll, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _PreCheckoutQuery(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Prefix(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
command: str,
|
||||
filters: BaseFilter = None,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_Prefix, self).__init__(
|
||||
prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher
|
||||
)
|
||||
|
||||
|
||||
class _ShippingQuery(_Handler):
|
||||
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||
super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _StringCommand(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
command: str,
|
||||
*,
|
||||
admin: bool = False,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _StringRegex(_Handler):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[str, Pattern],
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
||||
|
||||
|
||||
class _Type(_Handler):
|
||||
# noinspection PyShadowingBuiltins
|
||||
def __init__(
|
||||
self,
|
||||
type: Type[UT], # pylint: disable=W0622
|
||||
strict: bool = False,
|
||||
*,
|
||||
block: DVInput[bool] = DEFAULT_TRUE,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
): # pylint: disable=redefined-builtin
|
||||
super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher)
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class handler(_Handler):
|
||||
callback_query = _CallbackQuery
|
||||
chat_join_request = _ChatJoinRequest
|
||||
chat_member = _ChatMember
|
||||
chosen_inline_result = _ChosenInlineResult
|
||||
command = _Command
|
||||
inline_query = _InlineQuery
|
||||
message = _Message
|
||||
poll_answer = _PollAnswer
|
||||
pool = _Poll
|
||||
pre_checkout_query = _PreCheckoutQuery
|
||||
prefix = _Prefix
|
||||
shipping_query = _ShippingQuery
|
||||
string_command = _StringCommand
|
||||
string_regex = _StringRegex
|
||||
type = _Type
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]],
|
||||
*,
|
||||
admin: bool = False,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
self._type = handler_type
|
||||
super().__init__(admin=admin, dispatcher=dispatcher, **kwargs)
|
||||
|
||||
|
||||
class ConversationDataType(Enum):
|
||||
"""conversation handler 的类型"""
|
||||
|
||||
Entry = "entry"
|
||||
State = "state"
|
||||
Fallback = "fallback"
|
||||
|
||||
|
||||
class ConversationData(BaseModel):
|
||||
"""用于储存 conversation handler 的数据"""
|
||||
|
||||
type: ConversationDataType
|
||||
state: Optional[Any] = None
|
||||
|
||||
|
||||
class _ConversationType:
|
||||
_type: ClassVar[ConversationDataType]
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
cls._type = ConversationDataType(cls.__name__.lstrip("_").lower())
|
||||
|
||||
|
||||
def _entry(func: Callable[P, R]) -> Callable[P, R]:
|
||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry))
|
||||
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
class _State(_ConversationType):
|
||||
def __init__(self, state: Any) -> None:
|
||||
self.state = state
|
||||
|
||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state))
|
||||
return func
|
||||
|
||||
|
||||
def _fallback(func: Callable[P, R]) -> Callable[P, R]:
|
||||
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback))
|
||||
|
||||
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class conversation(_Handler):
|
||||
entry_point = _entry
|
||||
state = _State
|
||||
fallback = _fallback
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class ErrorHandlerData:
|
||||
block: bool
|
||||
func: Optional[Callable] = None
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class error_handler:
|
||||
_func: Callable[P, R]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
block: bool = DEFAULT_TRUE,
|
||||
):
|
||||
self._block = block
|
||||
|
||||
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
|
||||
self._func = func
|
||||
wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self)
|
||||
|
||||
handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, [])
|
||||
handler_datas.append(ErrorHandlerData(block=self._block))
|
||||
setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas)
|
||||
|
||||
return self._func
|
173
core/plugin/_job.py
Normal file
173
core/plugin/_job.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""插件"""
|
||||
import datetime
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram._utils.types import JSONDict
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from telegram.ext._utils.types import JobCallback
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.builtins.dispatcher import AbstractDispatcher
|
||||
|
||||
__all__ = ["TimeType", "job", "JobData"]
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
|
||||
|
||||
_JOB_ATTR_NAME = "_job_data"
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class JobData:
|
||||
name: str
|
||||
data: Any
|
||||
chat_id: int
|
||||
user_id: int
|
||||
type: str
|
||||
job_kwargs: JSONDict = field(default_factory=dict)
|
||||
kwargs: JSONDict = field(default_factory=dict)
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
||||
|
||||
|
||||
class _Job:
|
||||
kwargs: Dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = None,
|
||||
data: object = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.name = name
|
||||
self.data = data
|
||||
self.chat_id = chat_id
|
||||
self.user_id = user_id
|
||||
self.job_kwargs = {} if job_kwargs is None else job_kwargs
|
||||
self.kwargs = kwargs
|
||||
if dispatcher is None:
|
||||
from core.builtins.dispatcher import JobDispatcher
|
||||
|
||||
dispatcher = JobDispatcher
|
||||
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
def __call__(self, func: JobCallback) -> JobCallback:
|
||||
data = JobData(
|
||||
name=self.name,
|
||||
data=self.data,
|
||||
chat_id=self.chat_id,
|
||||
user_id=self.user_id,
|
||||
job_kwargs=self.job_kwargs,
|
||||
kwargs=self.kwargs,
|
||||
type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
|
||||
dispatcher=self.dispatcher,
|
||||
)
|
||||
if hasattr(func, _JOB_ATTR_NAME):
|
||||
job_datas = getattr(func, _JOB_ATTR_NAME)
|
||||
job_datas.append(data)
|
||||
setattr(func, _JOB_ATTR_NAME, job_datas)
|
||||
else:
|
||||
setattr(func, _JOB_ATTR_NAME, [data])
|
||||
return func
|
||||
|
||||
|
||||
class _RunOnce(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
when: TimeType,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when)
|
||||
|
||||
|
||||
class _RunRepeating(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
interval: Union[float, datetime.timedelta],
|
||||
first: TimeType = None,
|
||||
last: TimeType = None,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last
|
||||
)
|
||||
|
||||
|
||||
class _RunMonthly(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
when: datetime.time,
|
||||
day: int,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day)
|
||||
|
||||
|
||||
class _RunDaily(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
time: datetime.time,
|
||||
days: Tuple[int, ...] = tuple(range(7)),
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days)
|
||||
|
||||
|
||||
class _RunCustom(_Job):
|
||||
def __init__(
|
||||
self,
|
||||
data: object = None,
|
||||
name: str = None,
|
||||
chat_id: int = None,
|
||||
user_id: int = None,
|
||||
job_kwargs: JSONDict = None,
|
||||
*,
|
||||
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||
):
|
||||
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher)
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
class job:
|
||||
run_once = _RunOnce
|
||||
run_repeating = _RunRepeating
|
||||
run_monthly = _RunMonthly
|
||||
run_daily = _RunDaily
|
||||
run_custom = _RunCustom
|
303
core/plugin/_plugin.py
Normal file
303
core/plugin/_plugin.py
Normal file
@ -0,0 +1,303 @@
|
||||
"""插件"""
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from dataclasses import asdict
|
||||
from datetime import timedelta
|
||||
from functools import partial, wraps
|
||||
from itertools import chain
|
||||
from multiprocessing import RLock as Lock
|
||||
from types import MethodType
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from core.handler.adminhandler import AdminHandler
|
||||
from core.plugin._funcs import ConversationFuncs, PluginFuncs
|
||||
from core.plugin._handler import ConversationDataType
|
||||
from utils.const import WRAPPER_ASSIGNMENTS
|
||||
from utils.helpers import isabstract
|
||||
from utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.application import Application
|
||||
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
|
||||
from core.plugin._job import JobData
|
||||
from multiprocessing.synchronize import RLock as LockType
|
||||
|
||||
__all__ = ("Plugin", "PluginType", "get_all_plugins")
|
||||
|
||||
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
||||
|
||||
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
||||
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
||||
|
||||
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
||||
|
||||
_JOB_ATTR_NAME = "_job_data"
|
||||
|
||||
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
|
||||
|
||||
|
||||
class _Plugin(PluginFuncs):
|
||||
"""插件"""
|
||||
|
||||
_lock: ClassVar["LockType"] = Lock()
|
||||
_asyncio_lock: ClassVar["LockType"] = asyncio.Lock()
|
||||
_installed: bool = False
|
||||
|
||||
_handlers: Optional[List[HandlerType]] = None
|
||||
_error_handlers: Optional[List["ErrorHandlerData"]] = None
|
||||
_jobs: Optional[List[Job]] = None
|
||||
_application: "Optional[Application]" = None
|
||||
|
||||
def set_application(self, application: "Application") -> None:
|
||||
self._application = application
|
||||
|
||||
@property
|
||||
def application(self) -> "Application":
|
||||
if self._application is None:
|
||||
raise RuntimeError("No application was set for this Plugin.")
|
||||
return self._application
|
||||
|
||||
@property
|
||||
def handlers(self) -> List[HandlerType]:
|
||||
"""该插件的所有 handler"""
|
||||
with self._lock:
|
||||
if self._handlers is None:
|
||||
self._handlers = []
|
||||
|
||||
for attr in dir(self):
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
||||
):
|
||||
for data in datas:
|
||||
data: "HandlerData"
|
||||
if data.admin:
|
||||
self._handlers.append(
|
||||
AdminHandler(
|
||||
handler=data.type(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
),
|
||||
application=self.application,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handlers.append(
|
||||
data.type(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
)
|
||||
)
|
||||
return self._handlers
|
||||
|
||||
@property
|
||||
def error_handlers(self) -> List["ErrorHandlerData"]:
|
||||
with self._lock:
|
||||
if self._error_handlers is None:
|
||||
self._error_handlers = []
|
||||
for attr in dir(self):
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, []))
|
||||
):
|
||||
for data in datas:
|
||||
data: "ErrorHandlerData"
|
||||
data.func = func
|
||||
self._error_handlers.append(data)
|
||||
|
||||
return self._error_handlers
|
||||
|
||||
def _install_jobs(self) -> None:
|
||||
if self._jobs is None:
|
||||
self._jobs = []
|
||||
for attr in dir(self):
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and isinstance(func := getattr(self, attr), MethodType)
|
||||
and (datas := getattr(func, _JOB_ATTR_NAME, []))
|
||||
):
|
||||
for data in datas:
|
||||
data: "JobData"
|
||||
self._jobs.append(
|
||||
getattr(self.application.telegram.job_queue, data.type)(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
**{
|
||||
key: value
|
||||
for key, value in asdict(data).items()
|
||||
if key not in ["type", "kwargs", "dispatcher"]
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def jobs(self) -> List[Job]:
|
||||
with self._lock:
|
||||
if self._jobs is None:
|
||||
self._jobs = []
|
||||
self._install_jobs()
|
||||
return self._jobs
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化插件"""
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""销毁插件"""
|
||||
|
||||
async def install(self) -> None:
|
||||
"""安装"""
|
||||
group = id(self)
|
||||
if not self._installed:
|
||||
await self.initialize()
|
||||
# initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题
|
||||
async with self._asyncio_lock:
|
||||
self._install_jobs()
|
||||
|
||||
for h in self.handlers:
|
||||
if not isinstance(h, TypeHandler):
|
||||
self.application.telegram.add_handler(h, group)
|
||||
else:
|
||||
self.application.telegram.add_handler(h, -1)
|
||||
|
||||
for h in self.error_handlers:
|
||||
self.application.telegram.add_error_handler(h.func, h.block)
|
||||
self._installed = True
|
||||
|
||||
async def uninstall(self) -> None:
|
||||
"""卸载"""
|
||||
group = id(self)
|
||||
|
||||
with self._lock:
|
||||
if self._installed:
|
||||
if group in self.application.telegram.handlers:
|
||||
del self.application.telegram.handlers[id(self)]
|
||||
|
||||
for h in self.handlers:
|
||||
if isinstance(h, TypeHandler):
|
||||
self.application.telegram.remove_handler(h, -1)
|
||||
for h in self.error_handlers:
|
||||
self.application.telegram.remove_error_handler(h.func)
|
||||
|
||||
for j in self.application.telegram.job_queue.jobs():
|
||||
j.schedule_removal()
|
||||
await self.shutdown()
|
||||
self._installed = False
|
||||
|
||||
async def reload(self) -> None:
|
||||
await self.uninstall()
|
||||
await self.install()
|
||||
|
||||
|
||||
class _Conversation(_Plugin, ConversationFuncs, ABC):
|
||||
"""Conversation类"""
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
class Config(BaseModel):
|
||||
allow_reentry: bool = False
|
||||
per_chat: bool = True
|
||||
per_user: bool = True
|
||||
per_message: bool = False
|
||||
conversation_timeout: Optional[Union[float, timedelta]] = None
|
||||
name: Optional[str] = None
|
||||
map_to_parent: Optional[Dict[object, object]] = None
|
||||
block: bool = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
cls._conversation_kwargs = kwargs
|
||||
super(_Conversation, cls).__init_subclass__()
|
||||
return cls
|
||||
|
||||
@property
|
||||
def handlers(self) -> List[HandlerType]:
|
||||
with self._lock:
|
||||
if self._handlers is None:
|
||||
self._handlers = []
|
||||
|
||||
entry_points: List[HandlerType] = []
|
||||
states: Dict[Any, List[HandlerType]] = {}
|
||||
fallbacks: List[HandlerType] = []
|
||||
for attr in dir(self):
|
||||
if (
|
||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||
and (func := getattr(self, attr, None)) is not None
|
||||
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
||||
):
|
||||
conversation_data: "ConversationData"
|
||||
|
||||
handlers: List[HandlerType] = []
|
||||
for data in datas:
|
||||
handlers.append(
|
||||
data.type(
|
||||
callback=func,
|
||||
**data.kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None):
|
||||
if (_type := conversation_data.type) is ConversationDataType.Entry:
|
||||
entry_points.extend(handlers)
|
||||
elif _type is ConversationDataType.State:
|
||||
if conversation_data.state in states:
|
||||
states[conversation_data.state].extend(handlers)
|
||||
else:
|
||||
states[conversation_data.state] = handlers
|
||||
elif _type is ConversationDataType.Fallback:
|
||||
fallbacks.extend(handlers)
|
||||
else:
|
||||
self._handlers.extend(handlers)
|
||||
else:
|
||||
self._handlers.extend(handlers)
|
||||
if entry_points and states and fallbacks:
|
||||
kwargs = self._conversation_kwargs
|
||||
kwargs.update(self.Config().dict())
|
||||
self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
|
||||
else:
|
||||
temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
|
||||
reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items()))
|
||||
logger.warning(
|
||||
"'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason)
|
||||
)
|
||||
return self._handlers
|
||||
|
||||
|
||||
class Plugin(_Plugin, ABC):
|
||||
"""插件"""
|
||||
|
||||
Conversation = _Conversation
|
||||
|
||||
|
||||
PluginType = TypeVar("PluginType", bound=_Plugin)
|
||||
|
||||
|
||||
def get_all_plugins() -> Iterable[Type[PluginType]]:
|
||||
"""获取所有 Plugin 的子类"""
|
||||
return filter(
|
||||
lambda x: x.__name__[0] != "_" and not isabstract(x),
|
||||
chain(Plugin.__subclasses__(), _Conversation.__subclasses__()),
|
||||
)
|
@ -1,14 +0,0 @@
|
||||
from core.base.mysql import MySQL
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.service import init_service
|
||||
from .cache import QuizCache
|
||||
from .repositories import QuizRepository
|
||||
from .services import QuizService
|
||||
|
||||
|
||||
@init_service
|
||||
def create_quiz_service(mysql: MySQL, redis: RedisDB):
|
||||
_repository = QuizRepository(mysql)
|
||||
_cache = QuizCache(redis)
|
||||
_service = QuizService(_repository, _cache)
|
||||
return _service
|
@ -1,19 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from .models import Answer, Question
|
||||
|
||||
|
||||
def CreatQuestionFromSQLData(data: tuple) -> List[Question]:
|
||||
temp_list = []
|
||||
for temp_data in data:
|
||||
(question_id, text) = temp_data
|
||||
temp_list.append(Question(question_id, text))
|
||||
return temp_list
|
||||
|
||||
|
||||
def CreatAnswerFromSQLData(data: tuple) -> List[Answer]:
|
||||
temp_list = []
|
||||
for temp_data in data:
|
||||
(answer_id, question_id, is_correct, text) = temp_data
|
||||
temp_list.append(Answer(answer_id, question_id, is_correct, text))
|
||||
return temp_list
|
@ -1,98 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field, Column, Integer, ForeignKey
|
||||
|
||||
from utils.baseobject import BaseObject
|
||||
from utils.typedefs import JSONDict
|
||||
|
||||
|
||||
class AnswerDB(SQLModel, table=True):
|
||||
__tablename__ = "answer"
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
id: int = Field(primary_key=True)
|
||||
question_id: Optional[int] = Field(
|
||||
sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT"))
|
||||
)
|
||||
is_correct: Optional[bool] = Field()
|
||||
text: Optional[str] = Field()
|
||||
|
||||
|
||||
class QuestionDB(SQLModel, table=True):
|
||||
__tablename__ = "question"
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
id: int = Field(primary_key=True)
|
||||
text: Optional[str] = Field()
|
||||
|
||||
|
||||
class Answer(BaseObject):
|
||||
def __init__(self, answer_id: int = 0, question_id: int = 0, is_correct: bool = True, text: str = ""):
|
||||
"""Answer类
|
||||
|
||||
:param answer_id: 答案ID
|
||||
:param question_id: 与之对应的问题ID
|
||||
:param is_correct: 该答案是否正确
|
||||
:param text: 答案文本
|
||||
"""
|
||||
self.answer_id = answer_id
|
||||
self.question_id = question_id
|
||||
self.text = text
|
||||
self.is_correct = is_correct
|
||||
|
||||
__slots__ = ("answer_id", "question_id", "text", "is_correct")
|
||||
|
||||
def to_database_data(self) -> AnswerDB:
|
||||
data = AnswerDB()
|
||||
data.id = self.answer_id
|
||||
data.question_id = self.question_id
|
||||
data.text = self.text
|
||||
data.is_correct = self.is_correct
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def de_database_data(cls, data: Optional[AnswerDB]) -> Optional["Answer"]:
|
||||
if data is None:
|
||||
return cls()
|
||||
return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct)
|
||||
|
||||
|
||||
class Question(BaseObject):
|
||||
def __init__(self, question_id: int = 0, text: str = "", answers: List[Answer] = None):
|
||||
"""Question类
|
||||
|
||||
:param question_id: 问题ID
|
||||
:param text: 问题文本
|
||||
:param answers: 答案列表
|
||||
"""
|
||||
self.question_id = question_id
|
||||
self.text = text
|
||||
self.answers = [] if answers is None else answers
|
||||
|
||||
def to_database_data(self) -> QuestionDB:
|
||||
data = QuestionDB()
|
||||
data.text = self.text
|
||||
data.id = self.question_id
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def de_database_data(cls, data: Optional[QuestionDB]) -> Optional["Question"]:
|
||||
if data is None:
|
||||
return cls()
|
||||
return cls(question_id=data.id, text=data.text)
|
||||
|
||||
def to_dict(self) -> JSONDict:
|
||||
data = super().to_dict()
|
||||
if self.answers:
|
||||
data["answers"] = [e.to_dict() for e in self.answers]
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def de_json(cls, data: Optional[JSONDict]) -> Optional["Question"]:
|
||||
data = cls._parse_data(data)
|
||||
if not data:
|
||||
return None
|
||||
data["answers"] = Answer.de_list(data.get("answers"))
|
||||
return cls(**data)
|
||||
|
||||
__slots__ = ("question_id", "text", "answers")
|
@ -1,10 +0,0 @@
|
||||
from core.service import init_service
|
||||
from .services import SearchServices as _SearchServices
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
@init_service
|
||||
def create_search_service():
|
||||
_service = _SearchServices()
|
||||
return _service
|
@ -1,31 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
__all__ = ["Service", "init_service"]
|
||||
|
||||
|
||||
class Service(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化"""
|
||||
|
||||
async def start(self):
|
||||
"""启动 service"""
|
||||
|
||||
async def stop(self):
|
||||
"""关闭 service"""
|
||||
|
||||
|
||||
def init_service(func: Callable):
|
||||
from core.bot import bot
|
||||
|
||||
if bot.is_running:
|
||||
try:
|
||||
service = bot.init_inject(func)
|
||||
logger.success(f'服务 "{service.__class__.__name__}" 初始化成功')
|
||||
bot.add_service(service)
|
||||
except Exception as e: # pylint: disable=W0703
|
||||
logger.exception(f"来自{func.__module__}的服务初始化失败:{e}")
|
||||
return func
|
0
core/services/__init__.py
Normal file
0
core/services/__init__.py
Normal file
5
core/services/cookies/__init__.py
Normal file
5
core/services/cookies/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""CookieService"""
|
||||
|
||||
from core.services.cookies.services import CookiesService, PublicCookiesService
|
||||
|
||||
__all__ = ("CookiesService", "PublicCookiesService")
|
@ -1,12 +1,15 @@
|
||||
from typing import List, Union
|
||||
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from core.services.cookies.error import CookiesCachePoolExhausted
|
||||
from utils.error import RegionNotFoundError
|
||||
from utils.models.base import RegionEnum
|
||||
from .error import CookiesCachePoolExhausted
|
||||
|
||||
__all__ = ("PublicCookiesCache",)
|
||||
|
||||
|
||||
class PublicCookiesCache:
|
||||
class PublicCookiesCache(BaseService.Component):
|
||||
"""使用优先级(score)进行排序,对使用次数最少的Cookies进行审核"""
|
||||
|
||||
def __init__(self, redis: RedisDB):
|
||||
@ -19,10 +22,9 @@ class PublicCookiesCache:
|
||||
def get_public_cookies_queue_name(self, region: RegionEnum):
|
||||
if region == RegionEnum.HYPERION:
|
||||
return f"{self.score_qname}:yuanshen"
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
if region == RegionEnum.HOYOLAB:
|
||||
return f"{self.score_qname}:genshin"
|
||||
else:
|
||||
raise RegionNotFoundError(region.name)
|
||||
raise RegionNotFoundError(region.name)
|
||||
|
||||
async def putback_public_cookies(self, uid: int, region: RegionEnum):
|
||||
"""重新添加单个到缓存列表
|
@ -7,11 +7,6 @@ class CookiesCachePoolExhausted(CookieServiceError):
|
||||
super().__init__("Cookies cache pool is exhausted")
|
||||
|
||||
|
||||
class CookiesNotFoundError(CookieServiceError):
|
||||
def __init__(self, user_id):
|
||||
super().__init__(f"{user_id} cookies not found")
|
||||
|
||||
|
||||
class TooManyRequestPublicCookies(CookieServiceError):
|
||||
def __init__(self, user_id):
|
||||
super().__init__(f"{user_id} too many request public cookies")
|
39
core/services/cookies/models.py
Normal file
39
core/services/cookies/models.py
Normal file
@ -0,0 +1,39 @@
|
||||
import enum
|
||||
from typing import Optional, Dict
|
||||
|
||||
from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index
|
||||
|
||||
from core.basemodel import RegionEnum
|
||||
|
||||
__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum")
|
||||
|
||||
|
||||
class CookiesStatusEnum(int, enum.Enum):
|
||||
STATUS_SUCCESS = 0
|
||||
INVALID_COOKIES = 1
|
||||
TOO_MANY_REQUESTS = 2
|
||||
|
||||
|
||||
class Cookies(SQLModel):
|
||||
__table_args__ = (
|
||||
Index("index_user_account", "user_id", "account_id", unique=True),
|
||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||
)
|
||||
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger()),
|
||||
)
|
||||
account_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger(),
|
||||
),
|
||||
)
|
||||
data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
|
||||
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
|
||||
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
||||
is_share: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||
|
||||
|
||||
class CookiesDataBase(Cookies, table=True):
|
||||
__tablename__ = "cookies"
|
55
core/services/cookies/repositories.py
Normal file
55
core/services/cookies/repositories.py
Normal file
@ -0,0 +1,55 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.dependence.mysql import MySQL
|
||||
from core.services.cookies.models import CookiesDataBase as Cookies
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
|
||||
__all__ = ("CookiesRepository",)
|
||||
|
||||
|
||||
class CookiesRepository(BaseService.Component):
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.engine = mysql.engine
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
) -> Optional[Cookies]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Cookies).where(Cookies.user_id == user_id)
|
||||
if account_id is not None:
|
||||
statement = statement.where(Cookies.account_id == account_id)
|
||||
if region is not None:
|
||||
statement = statement.where(Cookies.region == region)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, cookies: Cookies) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(cookies)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, cookies: Cookies) -> Cookies:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(cookies)
|
||||
await session.commit()
|
||||
await session.refresh(cookies)
|
||||
return cookies
|
||||
|
||||
async def delete(self, cookies: Cookies) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(cookies)
|
||||
await session.commit()
|
||||
|
||||
async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Cookies).where(Cookies.region == region)
|
||||
results = await session.exec(statement)
|
||||
cookies = results.all()
|
||||
return cookies
|
@ -1,67 +1,73 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import genshin
|
||||
from genshin import GenshinException, InvalidCookies, TooManyRequests, types, Game
|
||||
from genshin import Game, GenshinException, InvalidCookies, TooManyRequests, types
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.services.cookies.cache import PublicCookiesCache
|
||||
from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies
|
||||
from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum
|
||||
from core.services.cookies.repositories import CookiesRepository
|
||||
from utils.log import logger
|
||||
from utils.models.base import RegionEnum
|
||||
from .cache import PublicCookiesCache
|
||||
from .error import TooManyRequestPublicCookies, CookieServiceError
|
||||
from .models import CookiesStatusEnum
|
||||
from .repositories import CookiesNotFoundError, CookiesRepository
|
||||
|
||||
__all__ = ("CookiesService", "PublicCookiesService")
|
||||
|
||||
|
||||
class CookiesService:
|
||||
class CookiesService(BaseService):
|
||||
def __init__(self, cookies_repository: CookiesRepository) -> None:
|
||||
self._repository: CookiesRepository = cookies_repository
|
||||
|
||||
async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
||||
await self._repository.update_cookies(user_id, cookies, region)
|
||||
async def update(self, cookies: Cookies):
|
||||
await self._repository.update(cookies)
|
||||
|
||||
async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
||||
await self._repository.add_cookies(user_id, cookies, region)
|
||||
async def add(self, cookies: Cookies):
|
||||
await self._repository.add(cookies)
|
||||
|
||||
async def get_cookies(self, user_id: int, region: RegionEnum):
|
||||
return await self._repository.get_cookies(user_id, region)
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
) -> Optional[Cookies]:
|
||||
return await self._repository.get(user_id, account_id, region)
|
||||
|
||||
async def del_cookies(self, user_id: int, region: RegionEnum):
|
||||
return await self._repository.del_cookies(user_id, region)
|
||||
|
||||
async def add_or_update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
||||
try:
|
||||
await self.get_cookies(user_id, region)
|
||||
await self.update_cookies(user_id, cookies, region)
|
||||
except CookiesNotFoundError:
|
||||
await self.add_cookies(user_id, cookies, region)
|
||||
async def delete(self, cookies: Cookies) -> None:
|
||||
return await self._repository.delete(cookies)
|
||||
|
||||
|
||||
class PublicCookiesService:
|
||||
class PublicCookiesService(BaseService):
|
||||
def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache):
|
||||
self._cache = public_cookies_cache
|
||||
self._repository: CookiesRepository = cookies_repository
|
||||
self.count: int = 0
|
||||
self.user_times_limiter = 3 * 3
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info("正在初始化公共Cookies池")
|
||||
await self.refresh()
|
||||
logger.success("刷新公共Cookies池成功")
|
||||
|
||||
async def refresh(self):
|
||||
"""刷新公共Cookies 定时任务
|
||||
:return:
|
||||
"""
|
||||
user_list: List[int] = []
|
||||
cookies_list = await self._repository.get_all_cookies(RegionEnum.HYPERION) # 从数据库获取2
|
||||
cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2
|
||||
for cookies in cookies_list:
|
||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
||||
user_list.append(cookies.user_id)
|
||||
if len(user_list) > 0:
|
||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION)
|
||||
logger.info(f"国服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]")
|
||||
logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
||||
user_list.clear()
|
||||
cookies_list = await self._repository.get_all_cookies(RegionEnum.HOYOLAB)
|
||||
cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB)
|
||||
for cookies in cookies_list:
|
||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
||||
user_list.append(cookies.user_id)
|
||||
if len(user_list) > 0:
|
||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB)
|
||||
logger.info(f"国际服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]")
|
||||
logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
||||
|
||||
async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL):
|
||||
"""获取公共Cookies
|
||||
@ -71,20 +77,19 @@ class PublicCookiesService:
|
||||
"""
|
||||
user_times = await self._cache.incr_by_user_times(user_id)
|
||||
if int(user_times) > self.user_times_limiter:
|
||||
logger.warning(f"用户 [{user_id}] 使用公共Cookie次数已经到达上限")
|
||||
logger.warning("用户 %s 使用公共Cookie次数已经到达上限", user_id)
|
||||
raise TooManyRequestPublicCookies(user_id)
|
||||
while True:
|
||||
public_id, count = await self._cache.get_public_cookies(region)
|
||||
try:
|
||||
cookies = await self._repository.get_cookies(public_id, region)
|
||||
except CookiesNotFoundError:
|
||||
cookies = await self._repository.get(public_id, region=region)
|
||||
if cookies is None:
|
||||
await self._cache.delete_public_cookies(public_id, region)
|
||||
continue
|
||||
if region == RegionEnum.HYPERION:
|
||||
client = genshin.Client(cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.CHINESE)
|
||||
client = genshin.Client(cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.CHINESE)
|
||||
elif region == RegionEnum.HOYOLAB:
|
||||
client = genshin.Client(
|
||||
cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn"
|
||||
cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn"
|
||||
)
|
||||
else:
|
||||
raise CookieServiceError
|
||||
@ -101,13 +106,13 @@ class PublicCookiesService:
|
||||
logger.warning("Cookies无效 ")
|
||||
logger.exception(exc)
|
||||
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
||||
await self._repository.update_cookies_ex(cookies, region)
|
||||
await self._repository.update(cookies)
|
||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||
continue
|
||||
except TooManyRequests:
|
||||
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
|
||||
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
|
||||
await self._repository.update_cookies_ex(cookies, region)
|
||||
await self._repository.update(cookies)
|
||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||
continue
|
||||
except GenshinException as exc:
|
1
core/services/game/__init__.py
Normal file
1
core/services/game/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""GameService"""
|
@ -1,12 +1,16 @@
|
||||
from typing import List
|
||||
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.redisdb import RedisDB
|
||||
|
||||
__all__ = ["GameCache", "GameCacheForStrategy", "GameCacheForMaterial"]
|
||||
|
||||
|
||||
class GameCache:
|
||||
def __init__(self, redis: RedisDB, qname: str, ttl: int = 3600):
|
||||
qname: str
|
||||
|
||||
def __init__(self, redis: RedisDB, ttl: int = 3600):
|
||||
self.client = redis.client
|
||||
self.qname = qname
|
||||
self.ttl = ttl
|
||||
|
||||
async def get_url_list(self, character_name: str):
|
||||
@ -19,3 +23,11 @@ class GameCache:
|
||||
await self.client.lpush(qname, *str_list)
|
||||
await self.client.expire(qname, self.ttl)
|
||||
return await self.client.llen(qname)
|
||||
|
||||
|
||||
class GameCacheForStrategy(BaseService.Component, GameCache):
|
||||
qname = "game:strategy"
|
||||
|
||||
|
||||
class GameCacheForMaterial(BaseService.Component, GameCache):
|
||||
qname = "game:material"
|
@ -1,11 +1,14 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.services.game.cache import GameCacheForMaterial, GameCacheForStrategy
|
||||
from modules.apihelper.client.components.hyperion import Hyperion
|
||||
from .cache import GameCache
|
||||
|
||||
__all__ = ("GameMaterialService", "GameStrategyService")
|
||||
|
||||
|
||||
class GameStrategyService:
|
||||
def __init__(self, cache: GameCache, collections: Optional[List[int]] = None):
|
||||
class GameStrategyService(BaseService):
|
||||
def __init__(self, cache: GameCacheForStrategy, collections: Optional[List[int]] = None):
|
||||
self._cache = cache
|
||||
self._hyperion = Hyperion()
|
||||
if collections is None:
|
||||
@ -49,8 +52,8 @@ class GameStrategyService:
|
||||
return artwork_info.image_urls[0]
|
||||
|
||||
|
||||
class GameMaterialService:
|
||||
def __init__(self, cache: GameCache, collections: Optional[List[int]] = None):
|
||||
class GameMaterialService(BaseService):
|
||||
def __init__(self, cache: GameCacheForMaterial, collections: Optional[List[int]] = None):
|
||||
self._cache = cache
|
||||
self._hyperion = Hyperion()
|
||||
self._collections = [428421, 1164644] if collections is None else collections
|
||||
@ -91,9 +94,8 @@ class GameMaterialService:
|
||||
await self._cache.set_url_list(character_name, image_url_list)
|
||||
if len(image_url_list) == 0:
|
||||
return ""
|
||||
elif len(image_url_list) == 1:
|
||||
if len(image_url_list) == 1:
|
||||
return image_url_list[0]
|
||||
elif character_name in self._special:
|
||||
if character_name in self._special:
|
||||
return image_url_list[2]
|
||||
else:
|
||||
return image_url_list[1]
|
||||
return image_url_list[1]
|
3
core/services/players/__init__.py
Normal file
3
core/services/players/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .services import PlayersService
|
||||
|
||||
__all__ = ("PlayersService",)
|
2
core/services/players/error.py
Normal file
2
core/services/players/error.py
Normal file
@ -0,0 +1,2 @@
|
||||
class PlayerNotFoundError(Exception):
|
||||
pass
|
96
core/services/players/models.py
Normal file
96
core/services/players/models.py
Normal file
@ -0,0 +1,96 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, BaseSettings
|
||||
from sqlalchemy import TypeDecorator
|
||||
from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime
|
||||
|
||||
from core.basemodel import RegionEnum
|
||||
|
||||
try:
|
||||
import ujson as jsonlib
|
||||
except ImportError:
|
||||
import json as jsonlib
|
||||
|
||||
__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel")
|
||||
|
||||
|
||||
class Player(SQLModel):
|
||||
__table_args__ = (
|
||||
Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
|
||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||
)
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
account_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
||||
is_chosen: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||
|
||||
|
||||
class PlayersDataBase(Player, table=True):
|
||||
__tablename__ = "players"
|
||||
|
||||
|
||||
class ExtraPlayerInfo(BaseModel):
|
||||
class Config(BaseSettings.Config):
|
||||
json_loads = jsonlib.loads
|
||||
json_dumps = jsonlib.dumps
|
||||
|
||||
waifu_id: Optional[int] = None
|
||||
|
||||
|
||||
class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223
|
||||
impl = VARCHAR(length=521)
|
||||
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""
|
||||
:param value: ExtraPlayerInfo | obj | None
|
||||
:param dialect:
|
||||
:return:
|
||||
"""
|
||||
if value is not None:
|
||||
if isinstance(value, ExtraPlayerInfo):
|
||||
return value.json()
|
||||
raise TypeError
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
"""
|
||||
:param value: str | obj | None
|
||||
:param dialect:
|
||||
:return:
|
||||
"""
|
||||
if value is not None:
|
||||
return ExtraPlayerInfo.parse_raw(value)
|
||||
return None
|
||||
|
||||
|
||||
class PlayerInfo(SQLModel):
|
||||
__table_args__ = (
|
||||
Index("index_user_account_player", "user_id", "player_id", unique=True),
|
||||
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||
)
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||
nickname: Optional[str] = Field()
|
||||
signature: Optional[str] = Field()
|
||||
hand_image: Optional[int] = Field()
|
||||
name_card: Optional[int] = Field()
|
||||
extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType))
|
||||
create_time: Optional[datetime] = Field(
|
||||
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
||||
)
|
||||
last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
||||
is_update: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||
|
||||
|
||||
class PlayerInfoSQLModel(PlayerInfo, table=True):
|
||||
__tablename__ = "players_info"
|
109
core/services/players/repositories.py
Normal file
109
core/services/players/repositories.py
Normal file
@ -0,0 +1,109 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select, delete
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.dependence.mysql import MySQL
|
||||
from core.services.players.models import PlayerInfoSQLModel
|
||||
from core.services.players.models import PlayersDataBase as Player
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
|
||||
__all__ = ("PlayersRepository", "PlayerInfoRepository")
|
||||
|
||||
|
||||
class PlayersRepository(BaseService.Component):
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.engine = mysql.engine
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: Optional[int] = None,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
is_chosen: Optional[bool] = None,
|
||||
) -> Optional[Player]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Player).where(Player.user_id == user_id)
|
||||
if player_id is not None:
|
||||
statement = statement.where(Player.player_id == player_id)
|
||||
if account_id is not None:
|
||||
statement = statement.where(Player.account_id == account_id)
|
||||
if region is not None:
|
||||
statement = statement.where(Player.region == region)
|
||||
if is_chosen is not None:
|
||||
statement = statement.where(Player.is_chosen == is_chosen)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, player: Player) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, player: Player) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(player)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, player: Player) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
await session.refresh(player)
|
||||
|
||||
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Player).where(Player.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
players = results.all()
|
||||
return players
|
||||
|
||||
|
||||
class PlayerInfoRepository(BaseService.Component):
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.engine = mysql.engine
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: int,
|
||||
) -> Optional[PlayerInfoSQLModel]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = (
|
||||
select(PlayerInfoSQLModel)
|
||||
.where(PlayerInfoSQLModel.player_id == player_id)
|
||||
.where(PlayerInfoSQLModel.user_id == user_id)
|
||||
)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, player: PlayerInfoSQLModel) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, player: PlayerInfoSQLModel) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(player)
|
||||
await session.commit()
|
||||
|
||||
async def delete_by_id(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: int,
|
||||
) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = (
|
||||
delete(PlayerInfoSQLModel)
|
||||
.where(PlayerInfoSQLModel.player_id == player_id)
|
||||
.where(PlayerInfoSQLModel.user_id == user_id)
|
||||
)
|
||||
await session.execute(statement)
|
||||
|
||||
async def update(self, player: PlayerInfoSQLModel) -> None:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(player)
|
||||
await session.commit()
|
||||
await session.refresh(player)
|
184
core/services/players/services.py
Normal file
184
core/services/players/services.py
Normal file
@ -0,0 +1,184 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from aiohttp import ClientConnectorError
|
||||
from enkanetwork import (
|
||||
EnkaNetworkAPI,
|
||||
VaildateUIDError,
|
||||
HTTPException,
|
||||
EnkaPlayerNotFound,
|
||||
PlayerInfo as EnkaPlayerInfo,
|
||||
)
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.basemodel import RegionEnum
|
||||
from core.config import config
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo
|
||||
from core.services.players.repositories import PlayersRepository, PlayerInfoRepository
|
||||
from utils.enkanetwork import RedisCache
|
||||
from utils.log import logger
|
||||
from utils.patch.aiohttp import AioHttpTimeoutException
|
||||
|
||||
__all__ = ("PlayersService", "PlayerInfoService")
|
||||
|
||||
|
||||
class PlayersService(BaseService):
|
||||
def __init__(self, players_repository: PlayersRepository) -> None:
|
||||
self._repository = players_repository
|
||||
|
||||
async def get(
|
||||
self,
|
||||
user_id: int,
|
||||
player_id: Optional[int] = None,
|
||||
account_id: Optional[int] = None,
|
||||
region: Optional[RegionEnum] = None,
|
||||
is_chosen: Optional[bool] = None,
|
||||
) -> Optional[Player]:
|
||||
return await self._repository.get(user_id, player_id, account_id, region, is_chosen)
|
||||
|
||||
async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]:
|
||||
return await self._repository.get(user_id, region=region, is_chosen=True)
|
||||
|
||||
async def add(self, player: Player) -> None:
|
||||
await self._repository.add(player)
|
||||
|
||||
async def update(self, player: Player) -> None:
|
||||
await self._repository.update(player)
|
||||
|
||||
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
||||
return await self._repository.get_all_by_user_id(user_id)
|
||||
|
||||
async def remove_all_by_user_id(self, user_id: int):
|
||||
players = await self._repository.get_all_by_user_id(user_id)
|
||||
for player in players:
|
||||
await self._repository.delete(player)
|
||||
|
||||
async def delete(self, player: Player):
|
||||
await self._repository.delete(player)
|
||||
|
||||
|
||||
class PlayerInfoService(BaseService):
|
||||
def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository):
|
||||
self.cache = redis.client
|
||||
self._players_info_repository = players_info_repository
|
||||
self.enka_client = EnkaNetworkAPI(lang="chs", user_agent=config.enka_network_api_agent)
|
||||
self.enka_client.set_cache(RedisCache(redis.client, key="players_info:enka_network", ttl=60))
|
||||
self.qname = "players_info"
|
||||
|
||||
async def get_form_cache(self, player: Player):
|
||||
qname = f"{self.qname}:{player.user_id}:{player.player_id}"
|
||||
data = await self.cache.get(qname)
|
||||
if data is None:
|
||||
return None
|
||||
json_data = str(data, encoding="utf-8")
|
||||
return PlayerInfo.parse_raw(json_data)
|
||||
|
||||
async def set_form_cache(self, player: PlayerInfo):
|
||||
qname = f"{self.qname}:{player.user_id}:{player.player_id}"
|
||||
await self.cache.set(qname, player.json(), ex=60)
|
||||
|
||||
async def get_player_info_from_enka(self, player_id: int) -> Optional[EnkaPlayerInfo]:
|
||||
try:
|
||||
response = await self.enka_client.fetch_user(player_id, info=True)
|
||||
return response.player
|
||||
except (VaildateUIDError, EnkaPlayerNotFound, HTTPException) as exc:
|
||||
logger.warning("EnkaNetwork 请求失败: %s", str(exc))
|
||||
except AioHttpTimeoutException as exc:
|
||||
logger.warning("EnkaNetwork 请求超时: %s", str(exc))
|
||||
except ClientConnectorError as exc:
|
||||
logger.warning("EnkaNetwork 请求错误: %s", str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("EnkaNetwork 请求失败: %s", exc_info=exc)
|
||||
return None
|
||||
|
||||
async def get(self, player: Player) -> Optional[PlayerInfo]:
|
||||
player_info = await self.get_form_cache(player)
|
||||
if player_info is not None:
|
||||
return player_info
|
||||
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
|
||||
if player_info is None:
|
||||
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||
if player_info_enka is None:
|
||||
return None
|
||||
player_info = PlayerInfo(
|
||||
user_id=player.user_id,
|
||||
player_id=player.player_id,
|
||||
nickname=player_info_enka.nickname,
|
||||
signature=player_info_enka.signature,
|
||||
name_card=player_info_enka.namecard.id,
|
||||
hand_image=player_info_enka.avatar.id,
|
||||
create_time=datetime.now(),
|
||||
last_save_time=datetime.now(),
|
||||
is_update=True,
|
||||
)
|
||||
await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info))
|
||||
await self.set_form_cache(player_info)
|
||||
return player_info
|
||||
if player_info.is_update:
|
||||
expiration_time = datetime.now() - timedelta(days=7)
|
||||
if player_info.last_save_time is None or player_info.last_save_time <= expiration_time:
|
||||
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||
if player_info_enka is None:
|
||||
player_info.last_save_time = datetime.now()
|
||||
await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info))
|
||||
await self.set_form_cache(player_info)
|
||||
return player_info
|
||||
player_info.nickname = player_info_enka.nickname
|
||||
player_info.name_card = player_info_enka.namecard.id
|
||||
player_info.signature = player_info_enka.signature
|
||||
player_info.hand_image = player_info_enka.avatar.id
|
||||
player_info.nickname = player_info_enka.nickname
|
||||
player_info.last_save_time = datetime.now()
|
||||
await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info))
|
||||
await self.set_form_cache(player_info)
|
||||
return player_info
|
||||
|
||||
async def update_from_enka(self, player: Player) -> bool:
|
||||
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
|
||||
if player_info is not None:
|
||||
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||
if player_info_enka is None:
|
||||
return False
|
||||
player_info.nickname = player_info_enka.nickname
|
||||
player_info.name_card = player_info_enka.namecard.id
|
||||
player_info.signature = player_info_enka.signature
|
||||
player_info.hand_image = player_info_enka.avatar.id
|
||||
player_info.nickname = player_info_enka.nickname
|
||||
player_info.last_save_time = datetime.now()
|
||||
await self._players_info_repository.update(player_info)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def add_from_enka(self, player: Player) -> bool:
|
||||
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
|
||||
if player_info is None:
|
||||
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||
if player_info_enka is None:
|
||||
return False
|
||||
player_info = PlayerInfoSQLModel(
|
||||
user_id=player.user_id,
|
||||
player_id=player.player_id,
|
||||
nickname=player_info_enka.nickname,
|
||||
signature=player_info_enka.signature,
|
||||
name_card=player_info_enka.namecard.id,
|
||||
hand_image=player_info_enka.avatar.id,
|
||||
create_time=datetime.now(),
|
||||
last_save_time=datetime.now(),
|
||||
is_update=True,
|
||||
)
|
||||
await self._players_info_repository.add(player_info)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_form_sql(self, player: Player):
|
||||
return await self._players_info_repository.get(player.user_id, player.player_id)
|
||||
|
||||
async def delete_form_player(self, player: Player):
|
||||
await self._players_info_repository.delete_by_id(user_id=player.user_id, player_id=player.player_id)
|
||||
|
||||
async def add(self, player_info: PlayerInfo):
|
||||
await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info))
|
||||
|
||||
async def delete(self, player_info: PlayerInfo):
|
||||
await self._players_info_repository.delete(PlayerInfoSQLModel.from_orm(player_info))
|
1
core/services/quiz/__init__.py
Normal file
1
core/services/quiz/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""QuizService"""
|
@ -1,12 +1,13 @@
|
||||
from typing import List
|
||||
|
||||
import ujson
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from core.services.quiz.models import Answer, Question
|
||||
|
||||
from core.base.redisdb import RedisDB
|
||||
from .models import Answer, Question
|
||||
__all__ = ("QuizCache",)
|
||||
|
||||
|
||||
class QuizCache:
|
||||
class QuizCache(BaseService.Component):
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.question_qname = "quiz:question"
|
||||
@ -18,7 +19,7 @@ class QuizCache:
|
||||
data_list = [self.question_qname + f":{question_id}" for question_id in await self.client.lrange(qname, 0, -1)]
|
||||
data = await self.client.mget(data_list)
|
||||
for i in data:
|
||||
temp_list.append(Question.de_json(ujson.loads(i)))
|
||||
temp_list.append(Question.parse_raw(i))
|
||||
return temp_list
|
||||
|
||||
async def get_all_question_id_list(self) -> List[str]:
|
||||
@ -29,19 +30,19 @@ class QuizCache:
|
||||
qname = f"{self.question_qname}:{question_id}"
|
||||
data = await self.client.get(qname)
|
||||
json_data = str(data, encoding="utf-8")
|
||||
return Question.de_json(ujson.loads(json_data))
|
||||
return Question.parse_raw(json_data)
|
||||
|
||||
async def get_one_answer(self, answer_id: int) -> Answer:
|
||||
qname = f"{self.answer_qname}:{answer_id}"
|
||||
data = await self.client.get(qname)
|
||||
json_data = str(data, encoding="utf-8")
|
||||
return Answer.de_json(ujson.loads(json_data))
|
||||
return Answer.parse_raw(json_data)
|
||||
|
||||
async def add_question(self, question_list: List[Question] = None) -> int:
|
||||
if not question_list:
|
||||
return 0
|
||||
for question in question_list:
|
||||
await self.client.set(f"{self.question_qname}:{question.question_id}", ujson.dumps(question.to_dict()))
|
||||
await self.client.set(f"{self.question_qname}:{question.question_id}", question.json())
|
||||
question_id_list = [question.question_id for question in question_list]
|
||||
await self.client.lpush(f"{self.question_qname}:id_list", *question_id_list)
|
||||
return await self.client.llen(f"{self.question_qname}:id_list")
|
||||
@ -62,7 +63,7 @@ class QuizCache:
|
||||
if not answer_list:
|
||||
return 0
|
||||
for answer in answer_list:
|
||||
await self.client.set(f"{self.answer_qname}:{answer.answer_id}", ujson.dumps(answer.to_dict()))
|
||||
await self.client.set(f"{self.answer_qname}:{answer.answer_id}", answer.json())
|
||||
answer_id_list = [answer.answer_id for answer in answer_list]
|
||||
await self.client.lpush(f"{self.answer_qname}:id_list", *answer_id_list)
|
||||
return await self.client.llen(f"{self.answer_qname}:id_list")
|
57
core/services/quiz/models.py
Normal file
57
core/services/quiz/models.py
Normal file
@ -0,0 +1,57 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Column, Field, ForeignKey, Integer, SQLModel
|
||||
|
||||
__all__ = ("Answer", "AnswerDB", "Question", "QuestionDB")
|
||||
|
||||
|
||||
class AnswerDB(SQLModel, table=True):
|
||||
__tablename__ = "answer"
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True)
|
||||
)
|
||||
question_id: Optional[int] = Field(
|
||||
sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT"))
|
||||
)
|
||||
is_correct: Optional[bool] = Field()
|
||||
text: Optional[str] = Field()
|
||||
|
||||
|
||||
class QuestionDB(SQLModel, table=True):
|
||||
__tablename__ = "question"
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True)
|
||||
)
|
||||
text: Optional[str] = Field()
|
||||
|
||||
|
||||
class Answer(BaseModel):
|
||||
answer_id: int = 0
|
||||
question_id: int = 0
|
||||
is_correct: bool = True
|
||||
text: str = ""
|
||||
|
||||
def to_database_data(self) -> AnswerDB:
|
||||
return AnswerDB(id=self.answer_id, question_id=self.question_id, text=self.text, is_correct=self.is_correct)
|
||||
|
||||
@classmethod
|
||||
def de_database_data(cls, data: AnswerDB) -> Optional["Answer"]:
|
||||
return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct)
|
||||
|
||||
|
||||
class Question(BaseModel):
|
||||
question_id: int = 0
|
||||
text: str = ""
|
||||
answers: List[Answer] = []
|
||||
|
||||
def to_database_data(self) -> QuestionDB:
|
||||
return QuestionDB(text=self.text, id=self.question_id)
|
||||
|
||||
@classmethod
|
||||
def de_database_data(cls, data: QuestionDB) -> Optional["Question"]:
|
||||
return cls(question_id=data.id, text=data.text)
|
@ -2,54 +2,55 @@ from typing import List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base.mysql import MySQL
|
||||
from .models import AnswerDB, QuestionDB
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.mysql import MySQL
|
||||
from core.services.quiz.models import AnswerDB, QuestionDB
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
|
||||
__all__ = ("QuizRepository",)
|
||||
|
||||
|
||||
class QuizRepository:
|
||||
class QuizRepository(BaseService.Component):
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.mysql = mysql
|
||||
self.engine = mysql.engine
|
||||
|
||||
async def get_question_list(self) -> List[QuestionDB]:
|
||||
async with self.mysql.Session() as session:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
query = select(QuestionDB)
|
||||
results = await session.exec(query)
|
||||
questions = results.all()
|
||||
return questions
|
||||
return results.all()
|
||||
|
||||
async def get_answers_from_question_id(self, question_id: int) -> List[AnswerDB]:
|
||||
async with self.mysql.Session() as session:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
query = select(AnswerDB).where(AnswerDB.question_id == question_id)
|
||||
results = await session.exec(query)
|
||||
answers = results.all()
|
||||
return answers
|
||||
return results.all()
|
||||
|
||||
async def add_question(self, question: QuestionDB):
|
||||
async with self.mysql.Session() as session:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(question)
|
||||
await session.commit()
|
||||
|
||||
async def get_question_by_text(self, text: str) -> QuestionDB:
|
||||
async with self.mysql.Session() as session:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
query = select(QuestionDB).where(QuestionDB.text == text)
|
||||
results = await session.exec(query)
|
||||
question = results.first()
|
||||
return question[0]
|
||||
return results.first()
|
||||
|
||||
async def add_answer(self, answer: AnswerDB):
|
||||
async with self.mysql.Session() as session:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(answer)
|
||||
await session.commit()
|
||||
|
||||
async def delete_question_by_id(self, question_id: int):
|
||||
async with self.mysql.Session() as session:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(QuestionDB).where(QuestionDB.id == question_id)
|
||||
results = await session.exec(statement)
|
||||
question = results.one()
|
||||
await session.delete(question)
|
||||
|
||||
async def delete_answer_by_id(self, answer_id: int):
|
||||
async with self.mysql.Session() as session:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(AnswerDB).where(AnswerDB.id == answer_id)
|
||||
results = await session.exec(statement)
|
||||
answer = results.one()
|
@ -1,12 +1,15 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from .cache import QuizCache
|
||||
from .models import Answer, Question
|
||||
from .repositories import QuizRepository
|
||||
from core.base_service import BaseService
|
||||
from core.services.quiz.cache import QuizCache
|
||||
from core.services.quiz.models import Answer, Question
|
||||
from core.services.quiz.repositories import QuizRepository
|
||||
|
||||
__all__ = ("QuizService",)
|
||||
|
||||
|
||||
class QuizService:
|
||||
class QuizService(BaseService):
|
||||
def __init__(self, repository: QuizRepository, cache: QuizCache):
|
||||
self._repository = repository
|
||||
self._cache = cache
|
1
core/services/search/__init__.py
Normal file
1
core/services/search/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""SearchService"""
|
@ -1,12 +1,11 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
__all__ = ["BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList"]
|
||||
|
||||
from thefuzz import fuzz
|
||||
|
||||
__all__ = ("BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList")
|
||||
|
||||
|
||||
class BaseEntry(BaseModel):
|
||||
"""所有可搜索条目的基类。
|
@ -5,19 +5,22 @@ import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Tuple, List, Optional, Dict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import aiofiles
|
||||
from async_lru import alru_cache
|
||||
|
||||
from core.search.models import WeaponEntry, BaseEntry, WeaponsEntry, StrategyEntry, StrategyEntryList
|
||||
from core.base_service import BaseService
|
||||
from core.services.search.models import BaseEntry, StrategyEntry, StrategyEntryList, WeaponEntry, WeaponsEntry
|
||||
from utils.const import PROJECT_ROOT
|
||||
|
||||
__all__ = ("SearchServices",)
|
||||
|
||||
ENTRY_DAYA_PATH = PROJECT_ROOT.joinpath("data", "entry")
|
||||
ENTRY_DAYA_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class SearchServices:
|
||||
class SearchServices(BaseService):
|
||||
def __init__(self):
|
||||
self._lock = asyncio.Lock() # 访问和修改操作成员变量必须加锁操作
|
||||
self.weapons: List[WeaponEntry] = []
|
1
core/services/sign/__init__.py
Normal file
1
core/services/sign/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""SignService"""
|
@ -2,8 +2,10 @@ import enum
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import SQLModel, Field, Enum, Column, DateTime
|
||||
from sqlalchemy import func, BigInteger
|
||||
from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer
|
||||
|
||||
__all__ = ("SignStatusEnum", "Sign")
|
||||
|
||||
|
||||
class SignStatusEnum(int, enum.Enum):
|
||||
@ -19,10 +21,13 @@ class SignStatusEnum(int, enum.Enum):
|
||||
|
||||
class Sign(SQLModel, table=True):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
id: int = Field(primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.user_id")
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True))
|
||||
chat_id: Optional[int] = Field(default=None)
|
||||
time_created: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), server_default=func.now()))
|
||||
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), onupdate=func.now()))
|
||||
time_created: Optional[datetime] = Field(
|
||||
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
||||
)
|
||||
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
||||
status: Optional[SignStatusEnum] = Field(sa_column=Column(Enum(SignStatusEnum)))
|
50
core/services/sign/repositories.py
Normal file
50
core/services/sign/repositories.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.mysql import MySQL
|
||||
from core.services.sign.models import Sign
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
|
||||
__all__ = ("SignRepository",)
|
||||
|
||||
|
||||
class SignRepository(BaseService.Component):
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.engine = mysql.engine
|
||||
|
||||
async def add(self, sign: Sign):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(sign)
|
||||
await session.commit()
|
||||
|
||||
async def remove(self, sign: Sign):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(sign)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, sign: Sign) -> Sign:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(sign)
|
||||
await session.commit()
|
||||
await session.refresh(sign)
|
||||
return sign
|
||||
|
||||
async def get_by_user_id(self, user_id: int) -> Optional[Sign]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Sign).where(Sign.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def get_by_chat_id(self, chat_id: int) -> Optional[List[Sign]]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(Sign).where(Sign.chat_id == chat_id)
|
||||
results = await session.exec(statement)
|
||||
return results.all()
|
||||
|
||||
async def get_all(self) -> List[Sign]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
query = select(Sign)
|
||||
results = await session.exec(query)
|
||||
return results.all()
|
@ -1,8 +1,11 @@
|
||||
from .models import Sign
|
||||
from .repositories import SignRepository
|
||||
from core.base_service import BaseService
|
||||
from core.services.sign.models import Sign
|
||||
from core.services.sign.repositories import SignRepository
|
||||
|
||||
__all__ = ["SignServices"]
|
||||
|
||||
|
||||
class SignServices:
|
||||
class SignServices(BaseService):
|
||||
def __init__(self, sign_repository: SignRepository) -> None:
|
||||
self._repository: SignRepository = sign_repository
|
||||
|
1
core/services/template/__init__.py
Normal file
1
core/services/template/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""TemplateService"""
|
@ -3,10 +3,14 @@ import pickle # nosec B403
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.base_service import BaseService
|
||||
|
||||
from core.dependence.redisdb import RedisDB
|
||||
|
||||
__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"]
|
||||
|
||||
|
||||
class TemplatePreviewCache:
|
||||
class TemplatePreviewCache(BaseService.Component):
|
||||
"""暂存渲染模板的数据用于预览"""
|
||||
|
||||
def __init__(self, redis: RedisDB):
|
||||
@ -29,7 +33,7 @@ class TemplatePreviewCache:
|
||||
return f"{self.qname}:{key}"
|
||||
|
||||
|
||||
class HtmlToFileIdCache:
|
||||
class HtmlToFileIdCache(BaseService.Component):
|
||||
"""html to file_id 的缓存"""
|
||||
|
||||
def __init__(self, redis: RedisDB):
|
@ -1,10 +1,12 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, List
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from telegram import Message, InputMediaPhoto, InputMediaDocument
|
||||
from telegram import InputMediaDocument, InputMediaPhoto, Message
|
||||
|
||||
from core.template.cache import HtmlToFileIdCache
|
||||
from core.template.error import ErrorFileType, FileIdNotFound
|
||||
from core.services.template.cache import HtmlToFileIdCache
|
||||
from core.services.template.error import ErrorFileType, FileIdNotFound
|
||||
|
||||
__all__ = ["FileType", "RenderResult", "RenderGroupResult"]
|
||||
|
||||
|
||||
class FileType(Enum):
|
||||
@ -16,10 +18,9 @@ class FileType(Enum):
|
||||
"""对应的 Telegram media 类型"""
|
||||
if file_type == FileType.PHOTO:
|
||||
return InputMediaPhoto
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
if file_type == FileType.DOCUMENT:
|
||||
return InputMediaDocument
|
||||
else:
|
||||
raise ErrorFileType
|
||||
raise ErrorFileType
|
||||
|
||||
|
||||
class RenderResult:
|
@ -1,44 +1,31 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from urllib.parse import (
|
||||
urlencode,
|
||||
urljoin,
|
||||
urlsplit,
|
||||
)
|
||||
from urllib.parse import urlencode, urljoin, urlsplit
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import (
|
||||
FileResponse,
|
||||
HTMLResponse,
|
||||
)
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from jinja2 import (
|
||||
Environment,
|
||||
FileSystemLoader,
|
||||
Template,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader, Template
|
||||
from playwright.async_api import ViewportSize
|
||||
|
||||
from core.base.aiobrowser import AioBrowser
|
||||
from core.base.webserver import webapp
|
||||
from core.bot import bot
|
||||
from core.template.cache import (
|
||||
HtmlToFileIdCache,
|
||||
TemplatePreviewCache,
|
||||
)
|
||||
from core.template.error import QuerySelectorNotFound
|
||||
from core.template.models import (
|
||||
FileType,
|
||||
RenderResult,
|
||||
)
|
||||
from core.application import Application
|
||||
from core.base_service import BaseService
|
||||
from core.config import config as application_config
|
||||
from core.dependence.aiobrowser import AioBrowser
|
||||
from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
|
||||
from core.services.template.error import QuerySelectorNotFound
|
||||
from core.services.template.models import FileType, RenderResult
|
||||
from utils.const import PROJECT_ROOT
|
||||
from utils.log import logger
|
||||
|
||||
__all__ = ("TemplateService", "TemplatePreviewer")
|
||||
|
||||
class TemplateService:
|
||||
|
||||
class TemplateService(BaseService):
|
||||
def __init__(
|
||||
self,
|
||||
app: Application,
|
||||
browser: AioBrowser,
|
||||
html_to_file_id_cache: HtmlToFileIdCache,
|
||||
preview_cache: TemplatePreviewCache,
|
||||
@ -51,10 +38,12 @@ class TemplateService:
|
||||
loader=FileSystemLoader(template_dir),
|
||||
enable_async=True,
|
||||
autoescape=True,
|
||||
auto_reload=bot.config.debug,
|
||||
auto_reload=application_config.debug,
|
||||
)
|
||||
self.using_preview = application_config.debug and application_config.webserver.enable
|
||||
|
||||
self.previewer = TemplatePreviewer(self, preview_cache)
|
||||
if self.using_preview:
|
||||
self.previewer = TemplatePreviewer(self, preview_cache, app.web_app)
|
||||
|
||||
self.html_to_file_id_cache = html_to_file_id_cache
|
||||
|
||||
@ -66,10 +55,11 @@ class TemplateService:
|
||||
:param template_name: 模板文件名
|
||||
:param template_data: 模板数据
|
||||
"""
|
||||
start_time = time.time()
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = loop.time()
|
||||
template = self.get_template(template_name)
|
||||
html = await template.render_async(**template_data)
|
||||
logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}")
|
||||
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||
return html
|
||||
|
||||
async def render(
|
||||
@ -100,19 +90,20 @@ class TemplateService:
|
||||
:param filename: 文件名字
|
||||
:return:
|
||||
"""
|
||||
start_time = time.time()
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = loop.time()
|
||||
template = self.get_template(template_name)
|
||||
|
||||
if bot.config.debug:
|
||||
if self.using_preview:
|
||||
preview_url = await self.previewer.get_preview_url(template_name, template_data)
|
||||
logger.debug(f"调试模板 URL: {preview_url}")
|
||||
logger.debug("调试模板 URL: \n%s", preview_url)
|
||||
|
||||
html = await template.render_async(**template_data)
|
||||
logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}")
|
||||
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||
|
||||
file_id = await self.html_to_file_id_cache.get_data(html, file_type.name)
|
||||
if file_id and not bot.config.debug:
|
||||
logger.debug(f"{template_name} 命中缓存,返回 file_id {file_id}")
|
||||
if file_id and not application_config.debug:
|
||||
logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id)
|
||||
return RenderResult(
|
||||
html=html,
|
||||
photo=file_id,
|
||||
@ -125,7 +116,7 @@ class TemplateService:
|
||||
)
|
||||
|
||||
browser = await self._browser.get_browser()
|
||||
start_time = time.time()
|
||||
start_time = loop.time()
|
||||
page = await browser.new_page(viewport=viewport)
|
||||
uri = (PROJECT_ROOT / template.filename).as_uri()
|
||||
await page.goto(uri)
|
||||
@ -142,10 +133,10 @@ class TemplateService:
|
||||
if not clip:
|
||||
raise QuerySelectorNotFound
|
||||
except QuerySelectorNotFound:
|
||||
logger.warning(f"未找到 {query_selector} 元素")
|
||||
logger.warning("未找到 %s 元素", query_selector)
|
||||
png_data = await page.screenshot(clip=clip, full_page=full_page)
|
||||
await page.close()
|
||||
logger.debug(f"{template_name} 图片渲染使用了 {str(time.time() - start_time)}")
|
||||
logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||
return RenderResult(
|
||||
html=html,
|
||||
photo=png_data,
|
||||
@ -158,15 +149,21 @@ class TemplateService:
|
||||
)
|
||||
|
||||
|
||||
class TemplatePreviewer:
|
||||
def __init__(self, template_service: TemplateService, cache: TemplatePreviewCache):
|
||||
class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug):
|
||||
def __init__(
|
||||
self,
|
||||
template_service: TemplateService,
|
||||
cache: TemplatePreviewCache,
|
||||
web_app: FastAPI,
|
||||
):
|
||||
self.web_app = web_app
|
||||
self.template_service = template_service
|
||||
self.cache = cache
|
||||
self.register_routes()
|
||||
|
||||
async def get_preview_url(self, template: str, data: dict):
|
||||
"""获取预览 URL"""
|
||||
components = urlsplit(bot.config.webserver.url)
|
||||
components = urlsplit(application_config.webserver.url)
|
||||
path = urljoin("/preview/", template)
|
||||
query = {}
|
||||
|
||||
@ -176,12 +173,13 @@ class TemplatePreviewer:
|
||||
await self.cache.set_data(key, data)
|
||||
query["key"] = key
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
return components._replace(path=path, query=urlencode(query)).geturl()
|
||||
|
||||
def register_routes(self):
|
||||
"""注册预览用到的路由"""
|
||||
|
||||
@webapp.get("/preview/{path:path}")
|
||||
@self.web_app.get("/preview/{path:path}")
|
||||
async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612
|
||||
# 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源
|
||||
if not path.endswith(".html"):
|
||||
@ -206,4 +204,4 @@ class TemplatePreviewer:
|
||||
for name in ["cache", "resources"]:
|
||||
directory = PROJECT_ROOT / name
|
||||
directory.mkdir(exist_ok=True)
|
||||
webapp.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)
|
||||
self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)
|
0
core/services/users/__init__.py
Normal file
0
core/services/users/__init__.py
Normal file
24
core/services/users/cache.py
Normal file
24
core/services/users/cache.py
Normal file
@ -0,0 +1,24 @@
|
||||
from typing import List
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.redisdb import RedisDB
|
||||
|
||||
__all__ = ("UserAdminCache",)
|
||||
|
||||
|
||||
class UserAdminCache(BaseService.Component):
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "users:admin"
|
||||
|
||||
async def ismember(self, user_id: int) -> bool:
|
||||
return self.client.sismember(self.qname, user_id)
|
||||
|
||||
async def get_all(self) -> List[int]:
|
||||
return [int(str_data) for str_data in await self.client.smembers(self.qname)]
|
||||
|
||||
async def set(self, user_id: int) -> bool:
|
||||
return await self.client.sadd(self.qname, user_id)
|
||||
|
||||
async def remove(self, user_id: int) -> bool:
|
||||
return await self.client.srem(self.qname, user_id)
|
34
core/services/users/models.py
Normal file
34
core/services/users/models.py
Normal file
@ -0,0 +1,34 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer
|
||||
|
||||
__all__ = (
|
||||
"User",
|
||||
"UserDataBase",
|
||||
"PermissionsEnum",
|
||||
)
|
||||
|
||||
|
||||
class PermissionsEnum(int, enum.Enum):
|
||||
OWNER = 1
|
||||
ADMIN = 2
|
||||
PUBLIC = 3
|
||||
|
||||
|
||||
class User(SQLModel):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
id: Optional[int] = Field(
|
||||
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||
)
|
||||
user_id: int = Field(unique=True, sa_column=Column(BigInteger()))
|
||||
permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum)))
|
||||
locale: Optional[str] = Field()
|
||||
ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
||||
ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
||||
is_banned: Optional[int] = Field()
|
||||
|
||||
|
||||
class UserDataBase(User, table=True):
|
||||
__tablename__ = "users"
|
44
core/services/users/repositories.py
Normal file
44
core/services/users/repositories.py
Normal file
@ -0,0 +1,44 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.mysql import MySQL
|
||||
from core.services.users.models import UserDataBase as User
|
||||
from core.sqlmodel.session import AsyncSession
|
||||
|
||||
__all__ = ("UserRepository",)
|
||||
|
||||
|
||||
class UserRepository(BaseService.Component):
|
||||
def __init__(self, mysql: MySQL):
|
||||
self.engine = mysql.engine
|
||||
|
||||
async def get_by_user_id(self, user_id: int) -> Optional[User]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(User).where(User.user_id == user_id)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, user: User):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, user: User) -> User:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
async def remove(self, user: User):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
|
||||
async def get_all(self) -> List[User]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(User)
|
||||
results = await session.exec(statement)
|
||||
return results.all()
|
79
core/services/users/services.py
Normal file
79
core/services/users/services.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from core.base_service import BaseService
|
||||
from core.config import config
|
||||
from core.services.users.cache import UserAdminCache
|
||||
from core.services.users.models import PermissionsEnum, UserDataBase as User
|
||||
from core.services.users.repositories import UserRepository
|
||||
|
||||
__all__ = ("UserService", "UserAdminService")
|
||||
|
||||
from utils.log import logger
|
||||
|
||||
|
||||
class UserService(BaseService):
|
||||
def __init__(self, user_repository: UserRepository) -> None:
|
||||
self._repository: UserRepository = user_repository
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> Optional[User]:
|
||||
"""从数据库获取用户信息
|
||||
:param user_id:用户ID
|
||||
:return: User
|
||||
"""
|
||||
return await self._repository.get_by_user_id(user_id)
|
||||
|
||||
async def remove(self, user: User):
|
||||
return await self._repository.remove(user)
|
||||
|
||||
async def update_user(self, user: User):
|
||||
return await self._repository.add(user)
|
||||
|
||||
|
||||
class UserAdminService(BaseService):
|
||||
def __init__(self, user_repository: UserRepository, cache: UserAdminCache):
|
||||
self.user_repository = user_repository
|
||||
self._cache = cache
|
||||
|
||||
async def initialize(self):
|
||||
owner = config.owner
|
||||
if owner:
|
||||
user = await self.user_repository.get_by_user_id(owner)
|
||||
await self._cache.set(user.user_id)
|
||||
if user:
|
||||
if user.permissions != PermissionsEnum.OWNER:
|
||||
user.permissions = PermissionsEnum.OWNER
|
||||
await self.user_repository.update(user)
|
||||
else:
|
||||
user = User(user_id=owner, permissions=PermissionsEnum.OWNER)
|
||||
await self.user_repository.add(user)
|
||||
else:
|
||||
logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限")
|
||||
|
||||
async def is_admin(self, user_id: int) -> bool:
|
||||
return await self._cache.ismember(user_id)
|
||||
|
||||
async def get_admin_list(self) -> List[int]:
|
||||
return await self._cache.get_all()
|
||||
|
||||
async def add_admin(self, user_id: int) -> bool:
|
||||
user = await self.user_repository.get_by_user_id(user_id)
|
||||
if user:
|
||||
if user.permissions == PermissionsEnum.OWNER:
|
||||
return False
|
||||
if user.permissions != PermissionsEnum.ADMIN:
|
||||
user.permissions = PermissionsEnum.ADMIN
|
||||
await self.user_repository.update(user)
|
||||
else:
|
||||
user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN)
|
||||
await self.user_repository.add(user)
|
||||
return await self._cache.set(user.user_id)
|
||||
|
||||
async def delete_admin(self, user_id: int) -> bool:
|
||||
user = await self.user_repository.get_by_user_id(user_id)
|
||||
if user:
|
||||
if user.permissions == PermissionsEnum.OWNER:
|
||||
return True # 假装移除成功
|
||||
user.permissions = PermissionsEnum.PUBLIC
|
||||
await self.user_repository.update(user)
|
||||
return await self._cache.remove(user.user_id)
|
||||
return False
|
1
core/services/wiki/__init__.py
Normal file
1
core/services/wiki/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""WikiService"""
|
@ -1,10 +1,13 @@
|
||||
import ujson as json
|
||||
|
||||
from core.base.redisdb import RedisDB
|
||||
from core.base_service import BaseService
|
||||
from core.dependence.redisdb import RedisDB
|
||||
from modules.wiki.base import Model
|
||||
|
||||
__all__ = ["WikiCache"]
|
||||
|
||||
class WikiCache:
|
||||
|
||||
class WikiCache(BaseService.Component):
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "wiki"
|
@ -1,12 +1,15 @@
|
||||
from typing import List, NoReturn, Optional
|
||||
|
||||
from core.wiki.cache import WikiCache
|
||||
from core.base_service import BaseService
|
||||
from core.services.wiki.cache import WikiCache
|
||||
from modules.wiki.character import Character
|
||||
from modules.wiki.weapon import Weapon
|
||||
from utils.log import logger
|
||||
|
||||
__all__ = ["WikiService"]
|
||||
|
||||
class WikiService:
|
||||
|
||||
class WikiService(BaseService):
|
||||
def __init__(self, cache: WikiCache):
|
||||
self._cache = cache
|
||||
"""Redis 在这里的作用是作为持久化"""
|
||||
@ -18,7 +21,7 @@ class WikiService:
|
||||
|
||||
async def refresh_weapon(self) -> NoReturn:
|
||||
weapon_name_list = await Weapon.get_name_list()
|
||||
logger.info(f"一共找到 {len(weapon_name_list)} 把武器信息")
|
||||
logger.info("一共找到 %s 把武器信息", len(weapon_name_list))
|
||||
|
||||
weapon_list = []
|
||||
num = 0
|
||||
@ -26,7 +29,7 @@ class WikiService:
|
||||
weapon_list.append(weapon)
|
||||
num += 1
|
||||
if num % 10 == 0:
|
||||
logger.info(f"现在已经获取到 {num} 把武器信息")
|
||||
logger.info("现在已经获取到 %s 把武器信息", num)
|
||||
|
||||
logger.info("写入武器信息到Redis")
|
||||
self._weapon_list = weapon_list
|
||||
@ -35,7 +38,7 @@ class WikiService:
|
||||
|
||||
async def refresh_characters(self) -> NoReturn:
|
||||
character_name_list = await Character.get_name_list()
|
||||
logger.info(f"一共找到 {len(character_name_list)} 个角色信息")
|
||||
logger.info("一共找到 %s 个角色信息", len(character_name_list))
|
||||
|
||||
character_list = []
|
||||
num = 0
|
||||
@ -43,7 +46,7 @@ class WikiService:
|
||||
character_list.append(character)
|
||||
num += 1
|
||||
if num % 10 == 0:
|
||||
logger.info(f"现在已经获取到 {num} 个角色信息")
|
||||
logger.info("现在已经获取到 %s 个角色信息", num)
|
||||
|
||||
logger.info("写入角色信息到Redis")
|
||||
self._character_list = character_list
|
@ -1,11 +0,0 @@
|
||||
from core.base.mysql import MySQL
|
||||
from core.service import init_service
|
||||
from .repositories import SignRepository
|
||||
from .services import SignServices
|
||||
|
||||
|
||||
@init_service
|
||||
def create_game_strategy_service(mysql: MySQL):
|
||||
_repository = SignRepository(mysql)
|
||||
_service = SignServices(_repository)
|
||||
return _service
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user