♻️ PaiGram V4

Co-authored-by: luoshuijs <luoshuijs@outlook.com>
Co-authored-by: Karako <karakohear@gmail.com>
Co-authored-by: xtaodada <xtao@xtaolink.cn>
This commit is contained in:
洛水居室 2023-03-14 09:27:22 +08:00 committed by GitHub
parent baceace292
commit 233e7ab58d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
246 changed files with 8964 additions and 5895 deletions

View File

@ -1,6 +1,12 @@
# debug 开关 # debug 开关
DEBUG=false DEBUG=false
AUTO_RELOAD=false
RELOAD_DELAY=0.25
RELOAD_DIRS=[]
RELOAD_INCLUDE=[]
RELOAD_EXCLUDE=[]
# MySQL # MySQL
DB_HOST=127.0.0.1 DB_HOST=127.0.0.1
DB_PORT=3306 DB_PORT=3306
@ -17,14 +23,14 @@ REDIS_PASSWORD=""
# 联系 https://t.me/BotFather 使用 /newbot 命令创建机器人并获取 token # 联系 https://t.me/BotFather 使用 /newbot 命令创建机器人并获取 token
BOT_TOKEN="xxxxxxx" BOT_TOKEN="xxxxxxx"
# bot 管理员 # bot 所有者
ADMINS=[{ "username": "", "user_id": -1 }] OWNER=0
# 记录错误并发送消息通知开发人员 可选配置项 # 记录错误并发送消息通知开发人员 可选配置项
# ERROR_NOTIFICATION_CHAT_ID=chat_id # ERROR_NOTIFICATION_CHAT_ID=chat_id
# 文章推送群组 可选配置项 # 文章推送群组 可选配置项
# CHANNELS=[{ "name": "", "chat_id": 1}] # CHANNELS=[]
# 是否允许机器人邀请到其他群 默认不允许 如果允许 可以允许全部人或有认证选项 可选配置项 # 是否允许机器人邀请到其他群 默认不允许 如果允许 可以允许全部人或有认证选项 可选配置项
# JOIN_GROUPS = "NO_ALLOW" # JOIN_GROUPS = "NO_ALLOW"
@ -33,20 +39,20 @@ ADMINS=[{ "username": "", "user_id": -1 }]
# VERIFY_GROUPS=[] # VERIFY_GROUPS=[]
# logger 配置 可选配置项 # logger 配置 可选配置项
LOGGER_NAME="TGPaimon" # LOGGER_NAME="TGPaimon"
# 打印时的宽度 # 打印时的宽度
LOGGER_WIDTH=180 # LOGGER_WIDTH=180
# log 文件存放目录 # log 文件存放目录
LOGGER_LOG_PATH="logs" # LOGGER_LOG_PATH="logs"
# log 时间格式,参考 datetime.strftime # log 时间格式,参考 datetime.strftime
LOGGER_TIME_FORMAT="[%Y-%m-%d %X]" # LOGGER_TIME_FORMAT="[%Y-%m-%d %X]"
# log 高亮关键词 # log 高亮关键词
LOGGER_RENDER_KEYWORDS=["BOT"] # LOGGER_RENDER_KEYWORDS=["BOT"]
# traceback 相关配置 # traceback 相关配置
LOGGER_TRACEBACK_MAX_FRAMES=20 # LOGGER_TRACEBACK_MAX_FRAMES=20
LOGGER_LOCALS_MAX_DEPTH=0 # LOGGER_LOCALS_MAX_DEPTH=0
LOGGER_LOCALS_MAX_LENGTH=10 # LOGGER_LOCALS_MAX_LENGTH=10
LOGGER_LOCALS_MAX_STRING=80 # LOGGER_LOCALS_MAX_STRING=80
# 可被 logger 打印的 record 的名称(默认包含了 LOGGER_NAME # 可被 logger 打印的 record 的名称(默认包含了 LOGGER_NAME
LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"] LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"]
@ -77,7 +83,7 @@ LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"]
# ENKA_NETWORK_API_AGENT="" # ENKA_NETWORK_API_AGENT=""
# Web Server # Web Server
# 目前只用于预览模板,仅开发环境启动 # WEB_SWITCH=False # 是否开启
# WEB_URL=http://localhost:8080/ # WEB_URL=http://localhost:8080/
# WEB_HOST=localhost # WEB_HOST=localhost
# WEB_PORT=8080 # WEB_PORT=8080

54
.github/workflows/integration-test.yml vendored Normal file
View File

@ -0,0 +1,54 @@
name: Integration Test
on:
push:
branches:
- main
paths:
- 'tests/integration/**'
pull_request:
types: [ opened, synchronize ]
paths:
- 'core/services/**'
- 'core/dependence/**'
- 'tests/integration/**'
jobs:
pytest:
name: pytest
runs-on: ubuntu-latest
services:
mysql:
image: mysql:5.7
env:
MYSQL_DATABASE: integration_test
MYSQL_ROOT_PASSWORD: 123456test
ports:
- 3306:3306
options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
redis:
image: redis
ports:
- 6379:6379
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v2
with:
python-version: 3.11
- name: Setup integration test environment
run: cp tests/integration/.env.example .env && cp tests/integration/.env.example tests/integration/.env
- name: Create venv
run: |
pip install --upgrade pip
python3 -m venv venv
- name: Install requirements
run: |
source venv/bin/activate
python3 -m pip install --upgrade poetry
python3 -m poetry install --extras all
- name: Run test
run: |
source venv/bin/activate
python3 -m pytest tests/integration

View File

@ -1,19 +1,17 @@
name: test name: Test modules
on: on:
push: push:
branches: branches:
- main - main
paths: paths:
- 'tests/**' - 'tests/unit/**'
pull_request: pull_request:
types: [ opened, synchronize ] types: [ opened, synchronize ]
paths: paths:
- 'modules/apihelper/**' - 'modules/apihelper/**'
- 'modules/wiki/**' - 'modules/wiki/**'
- 'tests/**' - 'tests/unit/**'
schedule:
- cron: '0 4 * * 3'
jobs: jobs:
pytest: pytest:
@ -22,16 +20,15 @@ jobs:
continue-on-error: ${{ matrix.experimental }} continue-on-error: ${{ matrix.experimental }}
strategy: strategy:
matrix: matrix:
python-version: [ '3.10' ]
os: [ ubuntu-latest, windows-latest ] os: [ ubuntu-latest, windows-latest ]
experimental: [ false ]
fail-fast: False fail-fast: False
steps: steps:
- uses: actions/checkout@v3 - name: Checkout code
- name: Set up Python ${{ matrix.python-version }} uses: actions/checkout@v3
uses: actions/setup-python@v4 - name: Set up Python 3.11
uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: 3.11
- name: restore or create a python virtualenv - name: restore or create a python virtualenv
id: cache id: cache
uses: syphar/restore-virtualenv@v1.2 uses: syphar/restore-virtualenv@v1.2
@ -45,4 +42,4 @@ jobs:
poetry install --extras test poetry install --extras test
- name: Test with pytest - name: Test with pytest
run: | run: |
python -m pytest python -m pytest tests/unit

5
.gitignore vendored
View File

@ -58,6 +58,5 @@ plugins/private
.pytest_cache .pytest_cache
### mtp ### ### mtp ###
paimon.session paigram.session
PaimonBot.session paigram.session-journal
PaimonBot.session-journal

View File

@ -1,7 +1,6 @@
<h1 align="center">PaiGram</h1> <h1 align="center">PaiGram</h1>
<div align="center"> <div align="center">·<img src="https://img.shields.io/badge/python-3.11%2B-blue" alt="">
<img src="https://img.shields.io/badge/python-3.8%2B-blue" alt="">
<img src="https://img.shields.io/badge/works%20on-my%20machine-brightgreen" alt=""> <img src="https://img.shields.io/badge/works%20on-my%20machine-brightgreen" alt="">
<img src="https://img.shields.io/badge/status-%E5%92%95%E5%92%95%E5%92%95-blue" alt=""> <img src="https://img.shields.io/badge/status-%E5%92%95%E5%92%95%E5%92%95-blue" alt="">
<a href="https://black.readthedocs.io/en/stable/index.html"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="code_style" /></a> <a href="https://black.readthedocs.io/en/stable/index.html"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="code_style" /></a>
@ -19,7 +18,7 @@
## 环境需求 ## 环境需求
- Python 3.8+ - Python 3.11+
- MySQL - MySQL
- Redis - Redis

View File

@ -6,19 +6,13 @@ from logging.config import fileConfig
from typing import Iterator from typing import Iterator
from alembic import context from alembic import context
from sqlalchemy import ( from sqlalchemy import engine_from_config, pool
engine_from_config,
pool,
)
from sqlalchemy.engine import Connection from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import SQLModel from sqlmodel import SQLModel
from utils.const import ( from core.config import config as BotConfig
CORE_DIR, from utils.const import CORE_DIR, PLUGIN_DIR, PROJECT_ROOT
PLUGIN_DIR,
PROJECT_ROOT,
)
from utils.log import logger from utils.log import logger
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
@ -28,7 +22,7 @@ config = context.config
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:
fileConfig(config.config_file_name) fileConfig(config.config_file_name) # skipcq: PY-A6006
def scan_models() -> Iterator[str]: def scan_models() -> Iterator[str]:
@ -46,7 +40,7 @@ def import_models():
try: try:
import_module(pkg) # 导入 models import_module(pkg) # 导入 models
except Exception as e: # pylint: disable=W0703 except Exception as e: # pylint: disable=W0703
logger.error(f'在导入文件 "{pkg}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]') logger.error("在导入文件 %s 的过程中遇到了错误: \n[red bold]%s: %s[/]", pkg, type(e).__name__, e, extra={"markup": True})
# register our models for alembic to auto-generate migrations # register our models for alembic to auto-generate migrations
@ -61,14 +55,13 @@ target_metadata = SQLModel.metadata
# here we allow ourselves to pass interpolation vars to alembic.ini # here we allow ourselves to pass interpolation vars to alembic.ini
# from the application config module # from the application config module
from core.config import config as botConfig
section = config.config_ini_section section = config.config_ini_section
config.set_section_option(section, "DB_HOST", botConfig.mysql.host) config.set_section_option(section, "DB_HOST", BotConfig.mysql.host)
config.set_section_option(section, "DB_PORT", str(botConfig.mysql.port)) config.set_section_option(section, "DB_PORT", str(BotConfig.mysql.port))
config.set_section_option(section, "DB_USERNAME", botConfig.mysql.username) config.set_section_option(section, "DB_USERNAME", BotConfig.mysql.username)
config.set_section_option(section, "DB_PASSWORD", botConfig.mysql.password) config.set_section_option(section, "DB_PASSWORD", BotConfig.mysql.password)
config.set_section_option(section, "DB_DATABASE", botConfig.mysql.database) config.set_section_option(section, "DB_DATABASE", BotConfig.mysql.database)
def run_migrations_offline() -> None: def run_migrations_offline() -> None:

View File

@ -5,16 +5,19 @@ Revises:
Create Date: 2022-09-01 16:55:20.372560 Create Date: 2022-09-01 16:55:20.372560
""" """
from alembic import op from base64 import b64decode
import sqlalchemy as sa import sqlalchemy as sa
import sqlmodel import sqlmodel
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "9e9a36470cd5" revision = "9e9a36470cd5"
down_revision = None down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode()
old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode()
def upgrade() -> None: def upgrade() -> None:
@ -22,7 +25,7 @@ def upgrade() -> None:
op.create_table( op.create_table(
"question", "question",
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("text", sqlmodel.AutoString(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
mysql_charset="utf8mb4", mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci", mysql_collate="utf8mb4_general_ci",
@ -35,7 +38,7 @@ def upgrade() -> None:
nullable=True, nullable=True,
), ),
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("yuanshen_uid", sa.Integer(), nullable=True), sa.Column("yuanshen_uid", sa.Integer(), nullable=True),
sa.Column("genshin_uid", sa.Integer(), nullable=True), sa.Column("genshin_uid", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
@ -46,7 +49,7 @@ def upgrade() -> None:
op.create_table( op.create_table(
"admin", "admin",
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["user_id"], ["user_id"],
["user.user_id"], ["user.user_id"],
@ -60,7 +63,7 @@ def upgrade() -> None:
sa.Column("question_id", sa.Integer(), nullable=True), sa.Column("question_id", sa.Integer(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("is_correct", sa.Boolean(), nullable=True), sa.Column("is_correct", sa.Boolean(), nullable=True),
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("text", sqlmodel.AutoString(), nullable=True),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["question_id"], ["question_id"],
["question.id"], ["question.id"],
@ -72,7 +75,7 @@ def upgrade() -> None:
mysql_collate="utf8mb4_general_ci", mysql_collate="utf8mb4_general_ci",
) )
op.create_table( op.create_table(
"hoyoverse_cookies", old_cookies_database_name2,
sa.Column("cookies", sa.JSON(), nullable=True), sa.Column("cookies", sa.JSON(), nullable=True),
sa.Column( sa.Column(
"status", "status",
@ -85,7 +88,7 @@ def upgrade() -> None:
nullable=True, nullable=True,
), ),
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True), sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["user_id"], ["user_id"],
["user.user_id"], ["user.user_id"],
@ -95,7 +98,7 @@ def upgrade() -> None:
mysql_collate="utf8mb4_general_ci", mysql_collate="utf8mb4_general_ci",
) )
op.create_table( op.create_table(
"mihoyo_cookies", old_cookies_database_name1,
sa.Column("cookies", sa.JSON(), nullable=True), sa.Column("cookies", sa.JSON(), nullable=True),
sa.Column( sa.Column(
"status", "status",
@ -108,7 +111,7 @@ def upgrade() -> None:
nullable=True, nullable=True,
), ),
sa.Column("id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True), sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["user_id"], ["user_id"],
["user.user_id"], ["user.user_id"],
@ -119,6 +122,9 @@ def upgrade() -> None:
) )
op.create_table( op.create_table(
"sign", "sign",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("chat_id", sa.BigInteger(), nullable=True),
sa.Column( sa.Column(
"time_created", "time_created",
sa.DateTime(timezone=True), sa.DateTime(timezone=True),
@ -140,14 +146,11 @@ def upgrade() -> None:
), ),
nullable=True, nullable=True,
), ),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("chat_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["user_id"], ["user_id"],
["user.user_id"], ["user.user_id"],
), ),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id", "user_id"),
mysql_charset="utf8mb4", mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci", mysql_collate="utf8mb4_general_ci",
) )
@ -157,8 +160,8 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table("sign") op.drop_table("sign")
op.drop_table("mihoyo_cookies") op.drop_table(old_cookies_database_name1)
op.drop_table("hoyoverse_cookies") op.drop_table(old_cookies_database_name2)
op.drop_table("answer") op.drop_table("answer")
op.drop_table("admin") op.drop_table("admin")
op.drop_table("user") op.drop_table("user")

View File

@ -0,0 +1,301 @@
"""v4
Revision ID: ddcfba3c7d5c
Revises: 9e9a36470cd5
Create Date: 2023-02-11 17:07:18.170175
"""
import json
import logging
from base64 import b64decode
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy import text
from sqlalchemy.exc import NoSuchTableError
# revision identifiers, used by Alembic.
revision = "ddcfba3c7d5c"
down_revision = "9e9a36470cd5"
branch_labels = None
depends_on = None
old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode()
old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode()
logger = logging.getLogger(__name__)
def upgrade() -> None:
connection = op.get_bind()
# ### commands auto generated by Alembic - please adjust! ###
cookies_table = op.create_table(
"cookies",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("account_id", sa.BigInteger(), nullable=False),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column(
"status",
sa.Enum(
"STATUS_SUCCESS",
"INVALID_COOKIES",
"TOO_MANY_REQUESTS",
name="cookiesstatusenum",
),
nullable=True,
),
sa.Column(
"region",
sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"),
nullable=True,
),
sa.Column("is_share", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.Index("index_user_account", "user_id", "account_id", unique=True),
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
for old_cookies_database_name in (old_cookies_database_name1, old_cookies_database_name2):
try:
statement = f"SELECT * FROM {old_cookies_database_name};" # skipcq: BAN-B608
old_cookies_table_data = connection.execute(text(statement))
except NoSuchTableError:
logger.warning("Table '%s' doesn't exist", old_cookies_database_name)
continue
if old_cookies_table_data is None:
logger.warning("Old Cookies Database is None")
continue
for row in old_cookies_table_data:
try:
user_id = row["user_id"]
status = row["status"]
cookies_row = row["cookies"]
cookies_data = json.loads(cookies_row)
account_id = cookies_data.get("account_id")
if account_id is None: # Cleaning Data 清洗数据
account_id = cookies_data.get("ltuid")
else:
account_mid_v2 = cookies_data.get("account_mid_v2")
if account_mid_v2 is not None:
cookies_data.pop("account_id")
cookies_data.setdefault("account_uid_v2", account_id)
if old_cookies_database_name == old_cookies_database_name1:
region = "HYPERION"
else:
region = "HOYOLAB"
if account_id is None:
logger.warning("Can not get user account_id, user_id :%s", user_id)
continue
insert = cookies_table.insert().values(
user_id=int(user_id),
account_id=int(account_id),
status=status,
data=cookies_data,
region=region,
is_share=True,
)
with op.get_context().autocommit_block():
connection.execute(insert)
except Exception as exc: # pylint: disable=W0703
logger.error(
"Process %s->cookies Exception", old_cookies_database_name, exc_info=exc
) # pylint: disable=W0703
players_table = op.create_table(
"players",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("account_id", sa.BigInteger(), nullable=True),
sa.Column("player_id", sa.BigInteger(), nullable=False),
sa.Column(
"region",
sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"),
nullable=True,
),
sa.Column("is_chosen", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
try:
statement = "SELECT * FROM user;"
old_user_table_data = connection.execute(text(statement))
except NoSuchTableError:
logger.warning("Table 'user' doesn't exist")
return # should not happen
if old_user_table_data is not None:
for row in old_user_table_data:
try:
user_id = row["user_id"]
y_uid = row["yuanshen_uid"]
g_uid = row["genshin_uid"]
region = row["region"]
account_id = None
cookies_row = connection.execute(
cookies_table.select().where(cookies_table.c.user_id == user_id)
).first()
if cookies_row is not None:
account_id = cookies_row["account_id"]
if y_uid:
insert = players_table.insert().values(
user_id=int(user_id),
player_id=int(y_uid),
is_chosen=(region == "HYPERION"),
region="HYPERION",
account_id=account_id,
)
with op.get_context().autocommit_block():
connection.execute(insert)
if g_uid:
insert = players_table.insert().values(
user_id=int(user_id),
player_id=int(g_uid),
is_chosen=(region == "HOYOLAB"),
region="HOYOLAB",
account_id=account_id,
)
with op.get_context().autocommit_block():
connection.execute(insert)
except Exception as exc: # pylint: disable=W0703
logger.error("Process user->player Exception", exc_info=exc)
else:
logger.warning("Old User Database is None")
users_table = op.create_table(
"users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False, primary_key=True),
sa.Column(
"permissions",
sa.Enum("OWNER", "ADMIN", "PUBLIC", name="permissionsenum"),
nullable=True,
),
sa.Column("locale", sqlmodel.AutoString(), nullable=True),
sa.Column("is_banned", sa.BigInteger(), nullable=True),
sa.Column("ban_end_time", sa.DateTime(timezone=True), nullable=True),
sa.Column("ban_start_time", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id"),
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
try:
statement = "SELECT * FROM admin;"
old_user_table_data = connection.execute(text(statement))
except NoSuchTableError:
logger.warning("Table 'admin' doesn't exist")
return # should not happen
if old_user_table_data is not None:
for row in old_user_table_data:
try:
user_id = row["user_id"]
insert = users_table.insert().values(
user_id=int(user_id),
permissions="ADMIN",
)
with op.get_context().autocommit_block():
connection.execute(insert)
except Exception as exc: # pylint: disable=W0703
logger.error("Process admin->users Exception", exc_info=exc)
else:
logger.warning("Old User Database is None")
op.create_table(
"players_info",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("player_id", sa.BigInteger(), nullable=False),
sa.Column("nickname", sqlmodel.AutoString(length=128), nullable=True),
sa.Column("signature", sqlmodel.AutoString(length=255), nullable=True),
sa.Column("hand_image", sa.Integer(), nullable=True),
sa.Column("name_card", sa.Integer(), nullable=True),
sa.Column("extra_data", sa.VARCHAR(length=512), nullable=True),
sa.Column(
"create_time",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("last_save_time", sa.DateTime(timezone=True), nullable=True),
sa.Column("is_update", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.Index("index_user_player", "user_id", "player_id", unique=True),
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
op.drop_table(old_cookies_database_name1)
op.drop_table(old_cookies_database_name2)
op.drop_table("admin")
op.drop_constraint("sign_ibfk_1", "sign", type_="foreignkey")
op.drop_index("user_id", table_name="sign")
op.drop_table("user")
# ### end Alembic commands ###
def downgrade() -> None:
op.create_table(
"user",
sa.Column("region", sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"), nullable=True),
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False),
sa.Column("yuanshen_uid", sa.INTEGER(), autoincrement=False, nullable=True),
sa.Column("genshin_uid", sa.INTEGER(), autoincrement=False, nullable=True),
sa.PrimaryKeyConstraint("id"),
mysql_collate="utf8mb4_general_ci",
mysql_default_charset="utf8mb4",
mysql_engine="InnoDB",
)
op.create_index("user_id", "user", ["user_id"], unique=False)
op.create_table(
"admin",
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="admin_ibfk_1"),
sa.PrimaryKeyConstraint("id"),
mysql_collate="utf8mb4_general_ci",
mysql_default_charset="utf8mb4",
mysql_engine="InnoDB",
)
op.create_table(
old_cookies_database_name1,
sa.Column("cookies", sa.JSON(), nullable=True),
sa.Column(
"status",
sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"),
nullable=True,
),
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="mihoyo_cookies_ibfk_1"),
sa.PrimaryKeyConstraint("id"),
mysql_collate="utf8mb4_general_ci",
mysql_default_charset="utf8mb4",
mysql_engine="InnoDB",
)
op.create_table(
old_cookies_database_name2,
sa.Column("cookies", sa.JSON(), nullable=True),
sa.Column(
"status",
sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"),
nullable=True,
),
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="hoyoverse_cookies_ibfk_1"),
sa.PrimaryKeyConstraint("id"),
mysql_collate="utf8mb4_general_ci",
mysql_default_charset="utf8mb4",
mysql_engine="InnoDB",
)
op.create_foreign_key("sign_ibfk_1", "sign", "user", ["user_id"], ["user_id"])
op.create_index("user_id", "sign", ["user_id"], unique=False)
op.drop_table("users")
op.drop_table("players")
op.drop_table("cookies")
op.drop_table("players_info")
# ### end Alembic commands ###

View File

@ -1,14 +0,0 @@
from core.service import init_service
from core.base.mysql import MySQL
from core.base.redisdb import RedisDB
from core.admin.cache import BotAdminCache
from core.admin.repositories import BotAdminRepository
from core.admin.services import BotAdminService
@init_service
def create_bot_admin_service(mysql: MySQL, redis: RedisDB):
_cache = BotAdminCache(redis)
_repository = BotAdminRepository(mysql)
_service = BotAdminService(_repository, _cache)
return _service

View File

@ -1,38 +0,0 @@
from typing import List
from core.base.redisdb import RedisDB
class BotAdminCache:
def __init__(self, redis: RedisDB):
self.client = redis.client
self.qname = "bot:admin"
async def get_list(self):
return [int(str_data) for str_data in await self.client.lrange(self.qname, 0, -1)]
async def set_list(self, str_list: List[int], ttl: int = -1):
await self.client.ltrim(self.qname, 1, 0)
await self.client.lpush(self.qname, *str_list)
if ttl != -1:
await self.client.expire(self.qname, ttl)
count = await self.client.llen(self.qname)
return count
class GroupAdminCache:
def __init__(self, redis: RedisDB):
self.client = redis.client
self.qname = "group:admin_list"
async def get_chat_admin(self, chat_id: int):
qname = f"{self.qname}:{chat_id}"
return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)]
async def set_chat_admin(self, chat_id: int, admin_list: List[int]):
qname = f"{self.qname}:{chat_id}"
await self.client.ltrim(qname, 1, 0)
await self.client.lpush(qname, *admin_list)
await self.client.expire(qname, 60)
count = await self.client.llen(qname)
return count

View File

@ -1,8 +0,0 @@
from sqlmodel import SQLModel, Field
class Admin(SQLModel, table=True):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: int = Field(primary_key=True)
user_id: int = Field(foreign_key="user.user_id")

View File

@ -1,33 +0,0 @@
from typing import List, cast
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from core.admin.models import Admin
from core.base.mysql import MySQL
class BotAdminRepository:
def __init__(self, mysql: MySQL):
self.mysql = mysql
async def delete_by_user_id(self, user_id: int):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
statement = select(Admin).where(Admin.user_id == user_id)
results = await session.exec(statement)
admin = results.one()
await session.delete(admin)
async def add_by_user_id(self, user_id: int):
async with self.mysql.Session() as session:
admin = Admin(user_id=user_id)
session.add(admin)
await session.commit()
async def get_all_user_id(self) -> List[int]:
async with self.mysql.Session() as session:
query = select(Admin)
results = await session.exec(query)
admins = results.all()
return [admin[0].user_id for admin in admins]

View File

@ -1,60 +0,0 @@
from typing import List
from asyncmy.errors import IntegrityError
from telegram import Bot
from core.admin.cache import BotAdminCache, GroupAdminCache
from core.admin.repositories import BotAdminRepository
from core.config import config
from utils.log import logger
class BotAdminService:
def __init__(self, repository: BotAdminRepository, cache: BotAdminCache):
self._repository = repository
self._cache = cache
async def get_admin_list(self) -> List[int]:
admin_list = await self._cache.get_list()
if len(admin_list) == 0:
admin_list = await self._repository.get_all_user_id()
for config_admin in config.admins:
admin_list.append(config_admin.user_id)
await self._cache.set_list(admin_list)
return admin_list
async def add_admin(self, user_id: int) -> bool:
try:
await self._repository.add_by_user_id(user_id)
except IntegrityError:
logger.warning("用户 %s 已经存在 Admin 数据库", user_id)
admin_list = await self._repository.get_all_user_id()
for config_admin in config.admins:
admin_list.append(config_admin.user_id)
await self._cache.set_list(admin_list)
return True
async def delete_admin(self, user_id: int) -> bool:
try:
await self._repository.delete_by_user_id(user_id)
except ValueError:
return False
admin_list = await self._repository.get_all_user_id()
for config_admin in config.admins:
admin_list.append(config_admin.user_id)
await self._cache.set_list(admin_list)
return True
class GroupAdminService:
def __init__(self, cache: GroupAdminCache):
self._cache = cache
async def get_admins(self, bot: Bot, chat_id: int, extra_user: List[int]) -> List[int]:
admin_id_list = await self._cache.get_chat_admin(chat_id)
if len(admin_id_list) == 0:
admin_list = await bot.get_chat_administrators(chat_id)
admin_id_list = [admin.user.id for admin in admin_list]
await self._cache.set_chat_admin(chat_id, admin_id_list)
admin_id_list += extra_user
return admin_id_list

287
core/application.py Normal file
View File

@ -0,0 +1,287 @@
"""BOT"""
import asyncio
import signal
from functools import wraps
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
from ssl import SSLZeroReturnError
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar
import pytz
import uvicorn
from fastapi import FastAPI
from telegram import Bot, Update
from telegram.error import NetworkError, TelegramError, TimedOut
from telegram.ext import (
Application as TelegramApplication,
ApplicationBuilder as TelegramApplicationBuilder,
Defaults,
JobQueue,
)
from typing_extensions import ParamSpec
from uvicorn import Server
from core.config import config as application_config
from core.handler.limiterhandler import LimiterHandler
from core.manager import Managers
from core.override.telegram import HTTPXRequest
from utils.const import WRAPPER_ASSIGNMENTS
from utils.log import logger
from utils.models.signal import Singleton
if TYPE_CHECKING:
from asyncio import Task
from types import FrameType
__all__ = ("Application",)
R = TypeVar("R")
T = TypeVar("T")
P = ParamSpec("P")
class Application(Singleton):
"""Application"""
_web_server_task: Optional["Task"] = None
_startup_funcs: List[Callable] = []
_shutdown_funcs: List[Callable] = []
def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
self._running = False
self.managers = managers
self.telegram = telegram
self.web_server = web_server
self.managers.set_application(application=self) # 给 managers 设置 application
self.managers.build_executor("Application")
@classmethod
def build(cls):
managers = Managers()
telegram = (
TelegramApplicationBuilder()
.get_updates_read_timeout(application_config.update_read_timeout)
.get_updates_write_timeout(application_config.update_write_timeout)
.get_updates_connect_timeout(application_config.update_connect_timeout)
.get_updates_pool_timeout(application_config.update_pool_timeout)
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
.token(application_config.bot_token)
.request(
HTTPXRequest(
connection_pool_size=application_config.connection_pool_size,
proxy_url=application_config.proxy_url,
read_timeout=application_config.read_timeout,
write_timeout=application_config.write_timeout,
connect_timeout=application_config.connect_timeout,
pool_timeout=application_config.pool_timeout,
)
)
.build()
)
web_server = Server(
uvicorn.Config(
app=FastAPI(debug=application_config.debug),
port=application_config.webserver.port,
host=application_config.webserver.host,
log_config=None,
)
)
return cls(managers, telegram, web_server)
@property
def running(self) -> bool:
"""bot 是否正在运行"""
with self._lock:
return self._running
@property
def web_app(self) -> FastAPI:
"""fastapi app"""
return self.web_server.config.app
@property
def bot(self) -> Optional[Bot]:
return self.telegram.bot
@property
def job_queue(self) -> Optional[JobQueue]:
return self.telegram.job_queue
async def _on_startup(self) -> None:
for func in self._startup_funcs:
await self.managers.executor(func, block=getattr(func, "block", False))
async def _on_shutdown(self) -> None:
for func in self._shutdown_funcs:
await self.managers.executor(func, block=getattr(func, "block", False))
async def initialize(self):
"""BOT 初始化"""
self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制
await self.managers.start_dependency() # 启动基础服务
await self.managers.init_components() # 实例化组件
await self.managers.start_services() # 启动其他服务
await self.managers.install_plugins() # 安装插件
async def shutdown(self):
"""BOT 关闭"""
await self.managers.uninstall_plugins() # 卸载插件
await self.managers.stop_services() # 终止其他服务
await self.managers.stop_dependency() # 终止基础服务
async def start(self) -> None:
"""启动 BOT"""
logger.info("正在启动 BOT 中...")
def error_callback(exc: TelegramError) -> None:
"""错误信息回调"""
self.telegram.create_task(self.telegram.process_error(error=exc, update=None))
await self.telegram.initialize()
logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})
if application_config.webserver.enable: # 如果使用 web app
server_config = self.web_server.config
server_config.setup_event_loop()
if not server_config.loaded:
server_config.load()
self.web_server.lifespan = server_config.lifespan_class(server_config)
try:
await self.web_server.startup()
except OSError as e:
if e.errno == 10048:
logger.error("Web Server 端口被占用:%s", e)
logger.error("Web Server 启动失败,正在退出")
raise SystemExit from None
if self.web_server.should_exit:
logger.error("Web Server 启动失败,正在退出")
raise SystemExit from None
logger.success("Web Server 启动成功")
self._web_server_task = asyncio.create_task(self.web_server.main_loop())
for _ in range(5): # 连接至 telegram 服务器
try:
await self.telegram.updater.start_polling(
error_callback=error_callback, allowed_updates=Update.ALL_TYPES
)
break
except TimedOut:
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
continue
except NetworkError as e:
logger.exception()
if isinstance(e, SSLZeroReturnError):
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
else:
logger.error("网络连接出现问题, 请检查您的网络状况.")
raise SystemExit from e
await self.initialize()
logger.success("BOT 初始化成功")
logger.debug("BOT 开始启动")
await self._on_startup()
await self.telegram.start()
self._running = True
logger.success("BOT 启动成功")
def stop_signal_handler(self, signum: int):
"""终止信号处理"""
signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")}
logger.debug("接收到了终止信号 %s 正在退出...", signals[signum])
if self._web_server_task:
self._web_server_task.cancel()
async def idle(self) -> None:
"""在接收到中止信号之前堵塞loop"""
task = None
def stop_handler(signum: int, _: "FrameType") -> None:
self.stop_signal_handler(signum)
task.cancel()
for s in (SIGINT, SIGTERM, SIGABRT):
signal_func(s, stop_handler)
while True:
task = asyncio.create_task(asyncio.sleep(600))
try:
await task
except asyncio.CancelledError:
break
async def stop(self) -> None:
"""关闭"""
logger.info("BOT 正在关闭")
self._running = False
await self._on_shutdown()
if self.telegram.updater.running:
await self.telegram.updater.stop()
await self.shutdown()
if self.telegram.running:
await self.telegram.stop()
await self.telegram.shutdown()
if self.web_server is not None:
try:
await self.web_server.shutdown()
logger.info("Web Server 已经关闭")
except AttributeError:
pass
logger.success("BOT 关闭成功")
def launch(self) -> None:
"""启动"""
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(self.start())
loop.run_until_complete(self.idle())
except (SystemExit, KeyboardInterrupt) as exc:
logger.debug("接收到了终止信号BOT 即将关闭", exc_info=exc) # 接收到了终止信号
except NetworkError as e:
if isinstance(e, SSLZeroReturnError):
logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
else:
logger.critical("网络连接出现问题, 请检查您的网络状况.")
except Exception as e:
logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e)
finally:
loop.run_until_complete(self.stop())
if application_config.reload:
raise SystemExit from None
def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
"""注册一个在 BOT 启动时执行的函数"""
if func not in self._startup_funcs:
self._startup_funcs.append(func)
# noinspection PyTypeChecker
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapper
def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
"""注册一个在 BOT 停止时执行的函数"""
if func not in self._shutdown_funcs:
self._shutdown_funcs.append(func)
# noinspection PyTypeChecker
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapper

View File

@ -1,31 +0,0 @@
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel.ext.asyncio.session import AsyncSession
from typing_extensions import Self
from core.config import BotConfig
from core.service import Service
class MySQL(Service):
@classmethod
def from_config(cls, config: BotConfig) -> Self:
return cls(**config.mysql.dict())
def __init__(self, host: str, port: int, username: str, password: str, database: str):
self.database = database
self.password = password
self.user = username
self.port = port
self.host = host
self.url = f"mysql+asyncmy://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
self.engine = create_async_engine(self.url)
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
async def get_session(self):
"""获取会话"""
async with self.Session() as session:
yield session
async def stop(self):
self.Session.close_all()

View File

@ -1,64 +0,0 @@
import asyncio
import uvicorn
from fastapi import FastAPI
from core.config import (
BotConfig,
config as botConfig,
)
from core.service import Service
__all__ = ["webapp", "WebServer"]
webapp = FastAPI(debug=botConfig.debug)
@webapp.get("/")
def index():
return {"Hello": "Paimon"}
class WebServer(Service):
debug: bool
host: str
port: int
server: uvicorn.Server
_server_task: asyncio.Task
@classmethod
def from_config(cls, config: BotConfig) -> Service:
return cls(debug=config.debug, host=config.webserver.host, port=config.webserver.port)
def __init__(self, debug: bool, host: str, port: int):
self.debug = debug
self.host = host
self.port = port
self.server = uvicorn.Server(
uvicorn.Config(app=webapp, port=port, use_colors=False, host=host, log_config=None)
)
async def start(self):
"""启动 service"""
# 暂时只在开发环境启动 webserver 用于开发调试
if not self.debug:
return
# 防止 uvicorn server 拦截 signals
self.server.install_signal_handlers = lambda: None
self._server_task = asyncio.create_task(self.server.serve())
async def stop(self):
"""关闭 service"""
if not self.debug:
return
self.server.should_exit = True
# 等待 task 结束
await self._server_task

60
core/base_service.py Normal file
View File

@ -0,0 +1,60 @@
from abc import ABC
from itertools import chain
from typing import ClassVar, Iterable, Type, TypeVar
from typing_extensions import Self
from utils.helpers import isabstract
__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services")
class _BaseService:
"""服务基类"""
_is_component: ClassVar[bool] = False
_is_dependence: ClassVar[bool] = False
def __init_subclass__(cls, load: bool = True, **kwargs):
cls.is_dependence = cls._is_dependence
cls.is_component = cls._is_component
cls.load = load
async def __aenter__(self) -> Self:
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def initialize(self) -> None:
"""Initialize resources used by this service"""
async def shutdown(self) -> None:
"""Stop & clear resources used by this service"""
class _Dependence(_BaseService, ABC):
_is_dependence: ClassVar[bool] = True
class _Component(_BaseService, ABC):
_is_component: ClassVar[bool] = True
class BaseService(_BaseService, ABC):
Dependence: Type[_BaseService] = _Dependence
Component: Type[_BaseService] = _Component
BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService)
DependenceType = TypeVar("DependenceType", bound=_Dependence)
ComponentType = TypeVar("ComponentType", bound=_Component)
# noinspection PyProtectedMember
def get_all_services() -> Iterable[Type[_BaseService]]:
return filter(
lambda x: x.__name__[0] != "_" and x.load and not isabstract(x),
chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()),
)

29
core/basemodel.py Normal file
View File

@ -0,0 +1,29 @@
import enum
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
from pydantic import BaseSettings
__all__ = ("RegionEnum", "Settings")
class RegionEnum(int, enum.Enum):
"""账号数据所在服务器"""
NULL = 0
HYPERION = 1 # 米忽悠国服 hyperion
HOYOLAB = 2 # 米忽悠国际服 hoyolab
class Settings(BaseSettings):
def __new__(cls, *args, **kwargs):
cls.update_forward_refs()
return super(Settings, cls).__new__(cls) # pylint: disable=E1120
class Config(BaseSettings.Config):
case_sensitive = False
json_loads = jsonlib.loads
json_dumps = jsonlib.dumps

View File

@ -1,69 +0,0 @@
from telegram import Update, ReplyKeyboardRemove
from telegram.error import BadRequest, Forbidden
from telegram.ext import CallbackContext, ConversationHandler
from core.plugin import handler, conversation
from utils.bot import get_chat
from utils.log import logger
async def clean_message(context: CallbackContext):
job = context.job
message_id = job.data
chat_info = f"chat_id[{job.chat_id}]"
try:
chat = await get_chat(job.chat_id)
full_name = chat.full_name
if full_name:
chat_info = f"{full_name}[{chat.id}]"
else:
chat_info = f"{chat.title}[{chat.id}]"
except (BadRequest, Forbidden) as exc:
logger.warning("获取 chat info 失败 %s", exc.message)
except Exception as exc:
logger.warning("获取 chat info 消息失败 %s", str(exc))
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
try:
# noinspection PyTypeChecker
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
except BadRequest as exc:
if "not found" in exc.message:
logger.warning("删除消息 %s message_id[%s] 失败 消息不存在", chat_info, message_id)
elif "Message can't be deleted" in exc.message:
logger.warning("删除消息 %s message_id[%s] 失败 消息无法删除 可能是没有授权", chat_info, message_id)
else:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
except Forbidden as exc:
if "bot was kicked" in exc.message:
logger.warning("删除消息 %s message_id[%s] 失败 已经被踢出群", chat_info, message_id)
else:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
def add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int):
context.job_queue.run_once(
callback=clean_message,
when=delete_seconds,
data=message_id,
name=f"{chat_id}|{message_id}|clean_message",
chat_id=chat_id,
job_kwargs={"replace_existing": True, "id": f"{chat_id}|{message_id}|clean_message"},
)
class _BasePlugin:
@staticmethod
def _add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int = 60):
return add_delete_message_job(context, chat_id, message_id, delete_seconds)
class _Conversation(_BasePlugin):
@conversation.fallback
@handler.command(command="cancel", block=True)
async def cancel(self, update: Update, _: CallbackContext) -> int:
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
return ConversationHandler.END
class BasePlugin(_BasePlugin):
Conversation = _Conversation

View File

@ -1,345 +0,0 @@
import asyncio
import inspect
import os
from asyncio import CancelledError
from importlib import import_module
from multiprocessing import RLock as Lock
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, Iterator, List, NoReturn, Optional, TYPE_CHECKING, Type, TypeVar
import genshin
import pytz
from async_timeout import timeout
from telegram import Update
from telegram import __version__ as tg_version
from telegram.error import NetworkError, TimedOut
from telegram.ext import (
AIORateLimiter,
Application as TgApplication,
CallbackContext,
Defaults,
JobQueue,
MessageHandler,
filters,
TypeHandler,
)
from telegram.ext.filters import StatusUpdate
from core.config import BotConfig, config # pylint: disable=W0611
from core.error import ServiceNotFoundError
# noinspection PyProtectedMember
from core.plugin import Plugin, _Plugin
from core.service import Service
from metadata.scripts.metadatas import make_github_fast
from utils.const import PLUGIN_DIR, PROJECT_ROOT
from utils.log import logger
__all__ = ["bot"]
T = TypeVar("T")
PluginType = TypeVar("PluginType", bound=_Plugin)
try:
from telegram import __version_info__ as tg_version_info
except ImportError:
tg_version_info = (0, 0, 0, 0, 0) # type: ignore[assignment]
if tg_version_info < (20, 0, 0, "alpha", 6):
logger.warning(
"Bot与当前PTB版本 [cyan bold]%s[/] [red bold]不兼容[/],请更新到最新版本后使用 [blue bold]poetry install[/] 重新安装依赖",
tg_version,
extra={"markup": True},
)
class Bot:
_lock: ClassVar[Lock] = Lock()
_instance: ClassVar[Optional["Bot"]] = None
def __new__(cls, *args, **kwargs) -> "Bot":
"""实现单例"""
with cls._lock: # 使线程、进程安全
if cls._instance is None:
cls._instance = object.__new__(cls)
return cls._instance
app: Optional[TgApplication] = None
_config: BotConfig = config
_services: Dict[Type[T], T] = {}
_running: bool = False
def _inject(self, signature: inspect.Signature, target: Callable[..., T]) -> T:
kwargs = {}
for name, parameter in signature.parameters.items():
if name != "self" and parameter.annotation != inspect.Parameter.empty:
if value := self._services.get(parameter.annotation):
kwargs[name] = value
return target(**kwargs)
def init_inject(self, target: Callable[..., T]) -> T:
"""用于实例化Plugin的方法。用于给插件传入一些必要组件如 MySQL、Redis等"""
if isinstance(target, type):
signature = inspect.signature(target.__init__)
else:
signature = inspect.signature(target)
return self._inject(signature, target)
async def async_inject(self, target: Callable[..., T]) -> T:
return await self._inject(inspect.signature(target), target)
def _gen_pkg(self, root: Path) -> Iterator[str]:
"""生成可以用于 import_module 导入的字符串"""
for path in root.iterdir():
if not path.name.startswith("_"):
if path.is_dir():
yield from self._gen_pkg(path)
elif path.suffix == ".py":
yield str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".")
async def install_plugins(self):
"""安装插件"""
for pkg in self._gen_pkg(PLUGIN_DIR):
try:
import_module(pkg) # 导入插件
except Exception as e: # pylint: disable=W0703
logger.exception(
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
)
continue # 如有错误则继续
callback_dict: Dict[int, List[Callable]] = {}
for plugin_cls in {*Plugin.__subclasses__(), *Plugin.Conversation.__subclasses__()}:
path = f"{plugin_cls.__module__}.{plugin_cls.__name__}"
try:
plugin: PluginType = self.init_inject(plugin_cls)
if hasattr(plugin, "__async_init__"):
await self.async_inject(plugin.__async_init__)
handlers = plugin.handlers
for index, handler in enumerate(handlers):
if isinstance(handler, TypeHandler): # 对 TypeHandler 进行特殊处理,优先级必须设置 -1否则无用
handlers.pop(index)
self.app.add_handler(handler, group=-1)
self.app.add_handlers(handlers)
if handlers:
logger.debug('插件 "%s" 添加了 %s 个 handler ', path, len(handlers))
# noinspection PyProtectedMember
for priority, callback in plugin._new_chat_members_handler_funcs(): # pylint: disable=W0212
if not callback_dict.get(priority):
callback_dict[priority] = []
callback_dict[priority].append(callback)
error_handlers = plugin.error_handlers
for callback, block in error_handlers.items():
self.app.add_error_handler(callback, block)
if error_handlers:
logger.debug('插件 "%s" 添加了 %s 个 error handler ', path, len(error_handlers))
if jobs := plugin.jobs:
logger.debug('插件 "%s" 添加了 %s 个 jobs ', path, len(jobs))
logger.success('插件 "%s" 载入成功', path)
except Exception as e: # pylint: disable=W0703
logger.exception(
'在安装插件 "%s" 的过程中遇到了错误 [red bold]%s[/]', path, type(e).__name__, exc_info=e, extra={"markup": True}
)
if callback_dict:
num = sum(len(callback_dict[i]) for i in callback_dict)
async def _new_chat_member_callback(update: "Update", context: "CallbackContext"):
nonlocal callback
for _, value in callback_dict.items():
for callback in value:
await callback(update, context)
self.app.add_handler(
MessageHandler(callback=_new_chat_member_callback, filters=StatusUpdate.NEW_CHAT_MEMBERS, block=False)
)
logger.success(
"成功添加了 %s 个针对 [blue]%s[/] 的 [blue]MessageHandler[/]",
num,
StatusUpdate.NEW_CHAT_MEMBERS,
extra={"markup": True},
)
# special handler
from plugins.system.start import StartPlugin
self.app.add_handler(
MessageHandler(
callback=StartPlugin.unknown_command, filters=filters.COMMAND & filters.ChatType.PRIVATE, block=False
)
)
async def _start_base_services(self):
for pkg in self._gen_pkg(PROJECT_ROOT / "core/base"):
try:
import_module(pkg)
except Exception as e: # pylint: disable=W0703
logger.exception(
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
)
raise SystemExit from e
for base_service_cls in Service.__subclasses__():
try:
if hasattr(base_service_cls, "from_config"):
instance = base_service_cls.from_config(self._config)
else:
instance = self.init_inject(base_service_cls)
await instance.start()
logger.success('服务 "%s" 初始化成功', base_service_cls.__name__)
self._services.update({base_service_cls: instance})
except Exception as e:
logger.error('服务 "%s" 初始化失败', base_service_cls.__name__)
raise SystemExit from e
async def start_services(self):
"""启动服务"""
await self._start_base_services()
for path in (PROJECT_ROOT / "core").iterdir():
if not path.name.startswith("_") and path.is_dir() and path.name != "base":
pkg = str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".")
try:
import_module(pkg)
except Exception as e: # pylint: disable=W0703
logger.exception(
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]',
pkg,
type(e).__name__,
exc_info=e,
extra={"markup": True},
)
continue
async def stop_services(self):
"""关闭服务"""
if not self._services:
return
logger.info("正在关闭服务")
for _, service in filter(lambda x: not isinstance(x[1], TgApplication), self._services.items()):
async with timeout(5):
try:
if hasattr(service, "stop"):
if inspect.iscoroutinefunction(service.stop):
await service.stop()
else:
service.stop()
logger.success('服务 "%s" 关闭成功', service.__class__.__name__)
except CancelledError:
logger.warning('服务 "%s" 关闭超时', service.__class__.__name__)
except Exception as e: # pylint: disable=W0703
logger.exception('服务 "%s" 关闭失败', service.__class__.__name__, exc_info=e)
async def _post_init(self, context: CallbackContext) -> NoReturn:
logger.info("开始初始化 genshin.py 相关资源")
try:
# 替换为 fastgit 镜像源
for i in dir(genshin.utility.extdb):
if "_URL" in i:
setattr(
genshin.utility.extdb,
i,
make_github_fast(getattr(genshin.utility.extdb, i)),
)
await genshin.utility.update_characters_enka()
except Exception as exc: # pylint: disable=W0703
logger.error("初始化 genshin.py 相关资源失败")
logger.exception(exc)
else:
logger.success("初始化 genshin.py 相关资源成功")
self._services.update({CallbackContext: context})
logger.info("开始初始化服务")
await self.start_services()
logger.info("开始安装插件")
await self.install_plugins()
logger.info("BOT 初始化成功")
def launch(self) -> NoReturn:
"""启动机器人"""
self._running = True
logger.info("正在初始化BOT")
self.app = (
TgApplication.builder()
.read_timeout(self.config.read_timeout)
.write_timeout(self.config.write_timeout)
.connect_timeout(self.config.connect_timeout)
.pool_timeout(self.config.pool_timeout)
.get_updates_read_timeout(self.config.update_read_timeout)
.get_updates_write_timeout(self.config.update_write_timeout)
.get_updates_connect_timeout(self.config.update_connect_timeout)
.get_updates_pool_timeout(self.config.update_pool_timeout)
.rate_limiter(AIORateLimiter())
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
.token(self._config.bot_token)
.post_init(self._post_init)
.build()
)
try:
for _ in range(5):
try:
self.app.run_polling(
close_loop=False,
timeout=self.config.timeout,
allowed_updates=Update.ALL_TYPES,
)
break
except TimedOut:
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
continue
except NetworkError as e:
if "SSLZeroReturnError" in str(e):
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
else:
logger.error("网络连接出现问题, 请检查您的网络状况.")
break
except (SystemExit, KeyboardInterrupt):
pass
except Exception as e: # pylint: disable=W0703
logger.exception("BOT 执行过程中出现错误", exc_info=e)
finally:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.stop_services())
loop.close()
logger.info("BOT 已经关闭")
self._running = False
def find_service(self, target: Type[T]) -> T:
"""查找服务。若没找到则抛出 ServiceNotFoundError"""
if (result := self._services.get(target)) is None:
raise ServiceNotFoundError(target)
return result
def add_service(self, service: T) -> NoReturn:
"""添加服务。若已经有同类型的服务,则会抛出异常"""
if type(service) in self._services:
raise ValueError(f'Service "{type(service)}" is already existed.')
self.update_service(service)
def update_service(self, service: T):
"""更新服务。若服务不存在,则添加;若存在,则更新"""
self._services.update({type(service): service})
def contain_service(self, service: Any) -> bool:
"""判断服务是否存在"""
if isinstance(service, type):
return service in self._services
else:
return service in self._services.values()
@property
def job_queue(self) -> JobQueue:
return self.app.job_queue
@property
def services(self) -> Dict[Type[T], T]:
return self._services
@property
def config(self) -> BotConfig:
return self._config
@property
def is_running(self) -> bool:
return self._running
bot = Bot()

View File

@ -0,0 +1 @@
"""bot builtins"""

38
core/builtins/contexts.py Normal file
View File

@ -0,0 +1,38 @@
"""上下文管理"""
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from telegram.ext import CallbackContext
from telegram import Update
__all__ = [
"CallbackContextCV",
"UpdateCV",
"handler_contexts",
"job_contexts",
]
CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback")
UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate")
@contextmanager
def handler_contexts(update: "Update", context: "CallbackContext") -> None:
context_token = CallbackContextCV.set(context)
update_token = UpdateCV.set(update)
try:
yield
finally:
CallbackContextCV.reset(context_token)
UpdateCV.reset(update_token)
@contextmanager
def job_contexts(context: "CallbackContext") -> None:
token = CallbackContextCV.set(context)
try:
yield
finally:
CallbackContextCV.reset(token)

309
core/builtins/dispatcher.py Normal file
View File

@ -0,0 +1,309 @@
"""参数分发器"""
import asyncio
import inspect
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from functools import cached_property, lru_cache, partial, wraps
from inspect import Parameter, Signature
from itertools import chain
from types import GenericAlias, MethodType
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)
from arkowrapper import ArkoWrapper
from fastapi import FastAPI
from telegram import Bot as TelegramBot, Chat, Message, Update, User
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
from typing_extensions import ParamSpec
from uvicorn import Server
from core.application import Application
from utils.const import WRAPPER_ASSIGNMENTS
from utils.typedefs import R, T
__all__ = (
"catch",
"AbstractDispatcher",
"BaseDispatcher",
"HandlerDispatcher",
"JobDispatcher",
"dispatched",
)
P = ParamSpec("P")
TargetType = Union[Type, str, Callable[[Any], bool]]
_CATCH_TARGET_ATTR = "_catch_targets"
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorate(func: Callable[P, R]) -> Callable[P, R]:
setattr(func, _CATCH_TARGET_ATTR, targets)
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapper
return decorate
@lru_cache(64)
def get_signature(func: Union[type, Callable]) -> Signature:
if isinstance(func, type):
return inspect.signature(func.__init__)
return inspect.signature(func)
class AbstractDispatcher(ABC):
"""参数分发器"""
IGNORED_ATTRS = []
_args: List[Any] = []
_kwargs: Dict[Union[str, Type], Any] = {}
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
return self._application
def __init__(self, *args, **kwargs) -> None:
self._args = list(args)
self._kwargs = dict(kwargs)
for _, value in kwargs.items():
type_arg = type(value)
if type_arg != str:
self._kwargs[type_arg] = value
for arg in args:
type_arg = type(arg)
if type_arg != str:
self._kwargs[type_arg] = arg
@cached_property
def catch_funcs(self) -> List[MethodType]:
# noinspection PyTypeChecker
return list(
ArkoWrapper(dir(self))
.filter(lambda x: not x.startswith("_"))
.filter(
lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"]
)
.map(lambda x: getattr(self, x))
.filter(lambda x: isinstance(x, MethodType))
.filter(lambda x: hasattr(x, "_catch_targets"))
)
@cached_property
def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
result = {}
for catch_func in self.catch_funcs:
catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR)
for catch_target in catch_targets:
result[catch_target] = catch_func
return result
@cached_property
def dispatch_funcs(self) -> List[MethodType]:
return list(
ArkoWrapper(dir(self))
.filter(lambda x: x.startswith("dispatch_by_"))
.map(lambda x: getattr(self, x))
.filter(lambda x: isinstance(x, MethodType))
)
@abstractmethod
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
"""默认的 dispatch 方法"""
@abstractmethod
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
"""使用 catch_func 获取并分配参数"""
def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
"""将参数分配给函数,从而合成一个无需参数即可执行的函数"""
params = {}
signature = get_signature(func)
parameters: Dict[str, Parameter] = dict(signature.parameters)
for name, parameter in list(parameters.items()):
parameter: Parameter
if any(
[
name == "self" and isinstance(func, (type, MethodType)),
parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL],
]
):
del parameters[name]
continue
for dispatch_func in self.dispatch_funcs:
parameters[name] = dispatch_func(parameter)
for name, parameter in parameters.items():
if parameter.default != Parameter.empty:
params[name] = parameter.default
else:
params[name] = None
return partial(func, **params)
@catch(Application)
def catch_application(self) -> Application:
return self.application
class BaseDispatcher(AbstractDispatcher):
"""默认参数分发器"""
_instances: Sequence[Any]
def _get_kwargs(self) -> Dict[Type[T], T]:
result = self._get_default_kwargs()
result[AbstractDispatcher] = self
result.update(self._kwargs)
return result
def _get_default_kwargs(self) -> Dict[Type[T], T]:
application = self.application
_default_kwargs = {
FastAPI: application.web_app,
Server: application.web_server,
TelegramApplication: application.telegram,
TelegramBot: application.telegram.bot,
}
if not application.running:
for obj in chain(
application.managers.dependency,
application.managers.components,
application.managers.services,
application.managers.plugins,
):
_default_kwargs[type(obj)] = obj
return {k: v for k, v in _default_kwargs.items() if v is not None}
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
annotation = parameter.annotation
# noinspection PyTypeChecker
if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None:
parameter._default = value # pylint: disable=W0212
return parameter
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
annotation = parameter.annotation
if annotation != Any and isinstance(annotation, GenericAlias):
return parameter
catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name)
if catch_func is not None:
# noinspection PyUnresolvedReferences,PyProtectedMember
parameter._default = catch_func() # pylint: disable=W0212
return parameter
@catch(AbstractEventLoop)
def catch_loop(self) -> AbstractEventLoop:
return asyncio.get_event_loop()
class HandlerDispatcher(BaseDispatcher):
"""Handler 参数分发器"""
def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
super().__init__(update=update, context=context, **kwargs)
self._update = update
self._context = context
def dispatch(
self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
) -> Callable[..., R]:
self._update = update or self._update
self._context = context or self._context
if self._update is None:
from core.builtins.contexts import UpdateCV
self._update = UpdateCV.get()
if self._context is None:
from core.builtins.contexts import CallbackContextCV
self._context = CallbackContextCV.get()
return super().dispatch(func)
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
"""HandlerDispatcher 默认不使用 dispatch_by_default"""
return parameter
@catch(Update)
def catch_update(self) -> Update:
return self._update
@catch(CallbackContext)
def catch_context(self) -> CallbackContext:
return self._context
@catch(Message)
def catch_message(self) -> Message:
return self._update.effective_message
@catch(User)
def catch_user(self) -> User:
return self._update.effective_user
@catch(Chat)
def catch_chat(self) -> Chat:
return self._update.effective_chat
class JobDispatcher(BaseDispatcher):
"""Job 参数分发器"""
def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
super().__init__(context=context, **kwargs)
self._context = context
def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
self._context = context or self._context
if self._context is None:
from core.builtins.contexts import CallbackContextCV
self._context = CallbackContextCV.get()
return super().dispatch(func)
@catch("data")
def catch_data(self) -> Any:
return self._context.job.data
@catch(Job)
def catch_job(self) -> Job:
return self._context.job
@catch(CallbackContext)
def catch_context(self) -> CallbackContext:
return self._context
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
def decorate(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return dispatcher().dispatch(func)(*args, **kwargs)
return wrapper
return decorate

131
core/builtins/executor.py Normal file
View File

@ -0,0 +1,131 @@
"""执行器"""
import inspect
from functools import cached_property
from multiprocessing import RLock as Lock
from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar
from telegram import Update
from telegram.ext import CallbackContext
from typing_extensions import ParamSpec, Self
from core.builtins.contexts import handler_contexts, job_contexts
if TYPE_CHECKING:
from core.application import Application
from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher
from multiprocessing.synchronize import RLock as LockType
__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor")
T = TypeVar("T")
R = TypeVar("R")
P = ParamSpec("P")
class BaseExecutor:
"""执行器
Args:
name(str): 该执行器的名称执行器的名称是唯一的
只支持执行只拥有 POSITIONAL_OR_KEYWORD KEYWORD_ONLY 两种参数类型的函数
"""
_lock: ClassVar["LockType"] = Lock()
_instances: ClassVar[Dict[str, Self]] = {}
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
return self._application
def __new__(cls: Type[T], name: str, *args, **kwargs) -> T:
with cls._lock:
if (instance := cls._instances.get(name)) is None:
instance = object.__new__(cls)
instance.__init__(name, *args, **kwargs)
cls._instances.update({name: instance})
return instance
@cached_property
def name(self) -> str:
"""当前执行器的名称"""
return self._name
def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
self._name = name
self._dispatcher = dispatcher
class Executor(BaseExecutor, Generic[P, R]):
async def __call__(
self,
target: Callable[P, R],
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
**kwargs,
) -> R:
dispatcher = self._dispatcher or dispatcher
dispatcher_instance = dispatcher(**kwargs)
dispatcher_instance.set_application(application=self.application)
dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数
# 执行
if inspect.iscoroutinefunction(target):
result = await dispatched_func()
else:
result = dispatched_func()
return result
class HandlerExecutor(BaseExecutor, Generic[P, R]):
"""Handler专用执行器"""
_callback: Callable[P, R]
_dispatcher: "HandlerDispatcher"
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None:
if dispatcher is None:
from core.builtins.dispatcher import HandlerDispatcher
dispatcher = HandlerDispatcher
super().__init__("handler", dispatcher)
self._callback = func
self._dispatcher = dispatcher()
def set_application(self, application: "Application") -> None:
self._application = application
if self._dispatcher is not None:
self._dispatcher.set_application(application)
async def __call__(self, update: Update, context: CallbackContext) -> R:
with handler_contexts(update, context):
dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context)
return await dispatched_func()
class JobExecutor(BaseExecutor):
"""Job 专用执行器"""
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
if dispatcher is None:
from core.builtins.dispatcher import JobDispatcher
dispatcher = JobDispatcher
super().__init__("job", dispatcher)
self._callback = func
self._dispatcher = dispatcher()
def set_application(self, application: "Application") -> None:
self._application = application
if self._dispatcher is not None:
self._dispatcher.set_application(application)
async def __call__(self, context: CallbackContext) -> R:
with job_contexts(context):
dispatched_func = self._dispatcher.dispatch(self._callback, context=context)
return await dispatched_func()

185
core/builtins/reloader.py Normal file
View File

@ -0,0 +1,185 @@
import inspect
import multiprocessing
import os
import signal
import threading
from pathlib import Path
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
from watchfiles import watch
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
from utils.log import logger
from utils.typedefs import StrOrPath
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
__all__ = ("Reloader",)
multiprocessing.allow_connection_pickling()
spawn = multiprocessing.get_context("spawn")
class FileFilter:
"""监控文件过滤"""
def __init__(self, includes: List[str], excludes: List[str]) -> None:
default_includes = ["*.py"]
self.includes = [default for default in default_includes if default not in excludes]
self.includes.extend(includes)
self.includes = list(set(self.includes))
default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
self.excludes = [default for default in default_excludes if default not in includes]
self.exclude_dirs = []
for e in excludes:
p = Path(e)
try:
is_dir = p.is_dir()
except OSError:
is_dir = False
if is_dir:
self.exclude_dirs.append(p)
else:
self.excludes.append(e)
self.excludes = list(set(self.excludes))
def __call__(self, path: Path) -> bool:
for include_pattern in self.includes:
if path.match(include_pattern):
for exclude_dir in self.exclude_dirs:
if exclude_dir in path.parents:
return False
for exclude_pattern in self.excludes:
if path.match(exclude_pattern):
return False
return True
return False
class Reloader:
_target: Callable[..., None]
_process: "BaseProcess"
@property
def process(self) -> "BaseProcess":
return self._process
@property
def target(self) -> Callable[..., None]:
return self._target
def __init__(
self,
target: Callable[..., None],
*,
reload_delay: float = 0.25,
reload_dirs: List[StrOrPath] = None,
reload_includes: List[str] = None,
reload_excludes: List[str] = None,
):
if inspect.iscoroutinefunction(target):
raise ValueError("不支持异步函数")
self._target = target
self.reload_delay = reload_delay
_reload_dirs = []
for reload_dir in reload_dirs or []:
_reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir)))
self.reload_dirs = []
for reload_dir in _reload_dirs:
append = True
for parent in reload_dir.parents:
if parent in _reload_dirs:
append = False
break
if append:
self.reload_dirs.append(reload_dir)
if not self.reload_dirs:
logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})
self._should_exit = threading.Event()
frame = inspect.currentframe().f_back
self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]])
self.watcher = watch(
*self.reload_dirs,
watch_filter=None,
stop_event=self._should_exit,
yield_on_timeout=True,
)
def get_changes(self) -> Optional[List[Path]]:
if not self._process.is_alive():
logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
self._should_exit.set()
try:
changes = next(self.watcher)
except StopIteration:
return None
if changes:
unique_paths = {Path(c[1]) for c in changes}
return [p for p in unique_paths if self.watch_filter(p)]
return None
def __iter__(self) -> Iterator[Optional[List[Path]]]:
return self
def __next__(self) -> Optional[List[Path]]:
return self.get_changes()
def run(self) -> None:
self.startup()
for changes in self:
if changes:
logger.warning(
"检测到文件 %s 发生改变, 正在重载...",
[str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes],
extra={"tag": "Reloader"},
)
self.restart()
self.shutdown()
def signal_handler(self, *_) -> None:
"""当接收到结束信号量时"""
self._process.join(3)
if self._process.is_alive():
self._process.terminate()
self._process.join()
self._should_exit.set()
def startup(self) -> None:
"""启动进程"""
logger.info("目标进程正在启动", extra={"tag": "Reloader"})
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.signal_handler)
self._process = spawn.Process(target=self._target)
self._process.start()
logger.success("目标进程启动成功", extra={"tag": "Reloader"})
def restart(self) -> None:
"""重启进程"""
self._process.terminate()
self._process.join(10)
self._process = spawn.Process(target=self._target)
self._process.start()
logger.info("目标进程已经重载", extra={"tag": "Reloader"})
def shutdown(self) -> None:
"""关闭进程"""
self._process.terminate()
self._process.join(10)
logger.info("重载器已经关闭", extra={"tag": "Reloader"})

View File

@ -1,19 +1,15 @@
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import ( from typing import List, Optional, Union
List,
Optional,
Union,
)
import dotenv import dotenv
from pydantic import AnyUrl, BaseModel, Field from pydantic import AnyUrl, Field
from core.basemodel import Settings
from utils.const import PROJECT_ROOT from utils.const import PROJECT_ROOT
from utils.models.base import Settings
from utils.typedefs import NaturalNumber from utils.typedefs import NaturalNumber
__all__ = ["BotConfig", "config", "JoinGroups"] __all__ = ("ApplicationConfig", "config", "JoinGroups")
dotenv.load_dotenv() dotenv.load_dotenv()
@ -25,22 +21,12 @@ class JoinGroups(str, Enum):
ALLOW_ALL = "ALLOW_ALL" ALLOW_ALL = "ALLOW_ALL"
class ConfigChannel(BaseModel):
name: str
chat_id: int
class ConfigUser(BaseModel):
username: Optional[str]
user_id: int
class MySqlConfig(Settings): class MySqlConfig(Settings):
host: str = "127.0.0.1" host: str = "127.0.0.1"
port: int = 3306 port: int = 3306
username: str = None username: Optional[str] = None
password: str = None password: Optional[str] = None
database: str = None database: Optional[str] = None
class Config(Settings.Config): class Config(Settings.Config):
env_prefix = "db_" env_prefix = "db_"
@ -58,7 +44,7 @@ class RedisConfig(Settings):
class LoggerConfig(Settings): class LoggerConfig(Settings):
name: str = "TGPaimon" name: str = "TGPaimon"
width: int = 180 width: Optional[int] = None
time_format: str = "[%Y-%m-%d %X]" time_format: str = "[%Y-%m-%d %X]"
traceback_max_frames: int = 20 traceback_max_frames: int = 20
path: Path = PROJECT_ROOT / "logs" path: Path = PROJECT_ROOT / "logs"
@ -78,6 +64,9 @@ class MTProtoConfig(Settings):
class WebServerConfig(Settings): class WebServerConfig(Settings):
enable: bool = False
"""是否启用WebServer"""
url: AnyUrl = "http://localhost:8080" url: AnyUrl = "http://localhost:8080"
host: str = "localhost" host: str = "localhost"
port: int = 8080 port: int = 8080
@ -97,6 +86,16 @@ class ErrorConfig(Settings):
env_prefix = "error_" env_prefix = "error_"
class ReloadConfig(Settings):
delay: float = 0.25
dirs: List[str] = []
include: List[str] = []
exclude: List[str] = []
class Config(Settings.Config):
env_prefix = "reload_"
class NoticeConfig(Settings): class NoticeConfig(Settings):
user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!" user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!"
@ -104,24 +103,32 @@ class NoticeConfig(Settings):
env_prefix = "notice_" env_prefix = "notice_"
class PluginConfig(Settings): class ApplicationConfig(Settings):
download_file_max_size: int = 5
class Config(Settings.Config):
env_prefix = "plugin_"
class BotConfig(Settings):
debug: bool = False debug: bool = False
"""debug 开关"""
retry: int = 5
"""重试次数"""
auto_reload: bool = False
"""自动重载"""
proxy_url: Optional[AnyUrl] = None
"""代理链接"""
bot_token: str = "" bot_token: str = ""
"""BOT的token"""
owner: Optional[int] = None
channels: List[int] = []
"""文章推送群组"""
channels: List["ConfigChannel"] = []
admins: List["ConfigUser"] = []
verify_groups: List[Union[int, str]] = [] verify_groups: List[Union[int, str]] = []
"""启用群验证功能的群组"""
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
"""是否允许机器人被邀请到其它群组"""
timeout: int = 10 timeout: int = 10
connection_pool_size: int = 256
read_timeout: Optional[float] = None read_timeout: Optional[float] = None
write_timeout: Optional[float] = None write_timeout: Optional[float] = None
connect_timeout: Optional[float] = None connect_timeout: Optional[float] = None
@ -138,6 +145,7 @@ class BotConfig(Settings):
pass_challenge_app_key: str = "" pass_challenge_app_key: str = ""
pass_challenge_user_web: str = "" pass_challenge_user_web: str = ""
reload: ReloadConfig = ReloadConfig()
mysql: MySqlConfig = MySqlConfig() mysql: MySqlConfig = MySqlConfig()
logger: LoggerConfig = LoggerConfig() logger: LoggerConfig = LoggerConfig()
webserver: WebServerConfig = WebServerConfig() webserver: WebServerConfig = WebServerConfig()
@ -145,8 +153,7 @@ class BotConfig(Settings):
mtproto: MTProtoConfig = MTProtoConfig() mtproto: MTProtoConfig = MTProtoConfig()
error: ErrorConfig = ErrorConfig() error: ErrorConfig = ErrorConfig()
notice: NoticeConfig = NoticeConfig() notice: NoticeConfig = NoticeConfig()
plugin: PluginConfig = PluginConfig()
BotConfig.update_forward_refs() ApplicationConfig.update_forward_refs()
config = BotConfig() config = ApplicationConfig()

View File

@ -1,21 +0,0 @@
from core.base.mysql import MySQL
from core.base.redisdb import RedisDB
from core.cookies.cache import PublicCookiesCache
from core.cookies.repositories import CookiesRepository
from core.cookies.services import CookiesService, PublicCookiesService
from core.service import init_service
@init_service
def create_cookie_service(mysql: MySQL):
_repository = CookiesRepository(mysql)
_service = CookiesService(_repository)
return _service
@init_service
def create_public_cookie_service(mysql: MySQL, redis: RedisDB):
_repository = CookiesRepository(mysql)
_cache = PublicCookiesCache(redis)
_service = PublicCookiesService(_repository, _cache)
return _service

View File

@ -1,27 +0,0 @@
import enum
from typing import Optional, Dict
from sqlmodel import SQLModel, Field, JSON, Enum, Column
class CookiesStatusEnum(int, enum.Enum):
STATUS_SUCCESS = 0
INVALID_COOKIES = 1
TOO_MANY_REQUESTS = 2
class Cookies(SQLModel):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: int = Field(primary_key=True)
user_id: Optional[int] = Field(foreign_key="user.user_id")
cookies: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
class HyperionCookie(Cookies, table=True):
__tablename__ = "mihoyo_cookies"
class HoyolabCookie(Cookies, table=True):
__tablename__ = "hoyoverse_cookies"

View File

@ -1,109 +0,0 @@
from typing import cast, List
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlmodel.ext.asyncio.session import AsyncSession
from core.base.mysql import MySQL
from utils.error import RegionNotFoundError
from utils.models.base import RegionEnum
from .error import CookiesNotFoundError
from .models import HyperionCookie, HoyolabCookie, Cookies
class CookiesRepository:
def __init__(self, mysql: MySQL):
self.mysql = mysql
async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
db_data = HyperionCookie(user_id=user_id, cookies=cookies)
elif region == RegionEnum.HOYOLAB:
db_data = HoyolabCookie(user_id=user_id, cookies=cookies)
else:
raise RegionNotFoundError(region.name)
session.add(db_data)
await session.commit()
async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
else:
raise RegionNotFoundError(region.name)
results = await session.exec(statement)
db_cookies = results.first()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
db_cookies = db_cookies[0]
db_cookies.cookies = cookies
session.add(db_cookies)
await session.commit()
await session.refresh(db_cookies)
async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region not in [RegionEnum.HYPERION, RegionEnum.HOYOLAB]:
raise RegionNotFoundError(region.name)
session.add(cookies)
await session.commit()
await session.refresh(cookies)
async def get_cookies(self, user_id, region: RegionEnum) -> Cookies:
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
results = await session.exec(statement)
db_cookies = results.first()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
return db_cookies[0]
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
results = await session.exec(statement)
db_cookies = results.first()
if db_cookies is None:
raise CookiesNotFoundError(user_id)
return db_cookies[0]
else:
raise RegionNotFoundError(region.name)
async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]:
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie)
results = await session.exec(statement)
db_cookies = results.all()
return [cookies[0] for cookies in db_cookies]
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie)
results = await session.exec(statement)
db_cookies = results.all()
return [cookies[0] for cookies in db_cookies]
else:
raise RegionNotFoundError(region.name)
async def del_cookies(self, user_id, region: RegionEnum):
async with self.mysql.Session() as session:
session = cast(AsyncSession, session)
if region == RegionEnum.HYPERION:
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
elif region == RegionEnum.HOYOLAB:
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
else:
raise RegionNotFoundError(region.name)
results = await session.execute(statement)
try:
db_cookies = results.unique().scalar_one()
except NoResultFound as exc:
raise CookiesNotFoundError(user_id) from exc
await session.delete(db_cookies)
await session.commit()

View File

@ -0,0 +1 @@
"""基础服务"""

View File

@ -1,26 +1,40 @@
from typing import Optional from typing import Optional, TYPE_CHECKING
from playwright.async_api import Browser, Playwright, async_playwright, Error from playwright.async_api import Error, async_playwright
from core.service import Service from core.base_service import BaseService
from utils.log import logger from utils.log import logger
if TYPE_CHECKING:
from playwright.async_api import Playwright as AsyncPlaywright, Browser
__all__ = ("AioBrowser",)
class AioBrowser(BaseService.Dependence):
@property
def browser(self):
return self._browser
class AioBrowser(Service):
def __init__(self, loop=None): def __init__(self, loop=None):
self.browser: Optional[Browser] = None self._browser: Optional["Browser"] = None
self._playwright: Optional[Playwright] = None self._playwright: Optional["AsyncPlaywright"] = None
self._loop = loop self._loop = loop
async def start(self): async def get_browser(self):
if self._browser is None:
await self.initialize()
return self._browser
async def initialize(self):
if self._playwright is None: if self._playwright is None:
logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True}) logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True})
self._playwright = await async_playwright().start() self._playwright = await async_playwright().start()
logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True}) logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True})
if self.browser is None: if self._browser is None:
logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True}) logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True})
try: try:
self.browser = await self._playwright.chromium.launch(timeout=5000) self._browser = await self._playwright.chromium.launch(timeout=5000)
logger.success("[blue]Browser[/] 启动成功", extra={"markup": True}) logger.success("[blue]Browser[/] 启动成功", extra={"markup": True})
except Error as err: except Error as err:
if "playwright install" in str(err): if "playwright install" in str(err):
@ -33,15 +47,10 @@ class AioBrowser(Service):
raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium") raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium")
raise err raise err
return self.browser return self._browser
async def stop(self): async def shutdown(self):
if self.browser is not None: if self._browser is not None:
await self.browser.close() await self._browser.close()
if self._playwright is not None: if self._playwright is not None:
await self._playwright.stop() self._playwright.stop()
async def get_browser(self) -> Browser:
if self.browser is None:
await self.start()
return self.browser

View File

@ -0,0 +1,16 @@
from asyncio import AbstractEventLoop
from playwright.async_api import Browser, Playwright as AsyncPlaywright
from core.base_service import BaseService
__all__ = ("AioBrowser",)
class AioBrowser(BaseService.Dependence):
_browser: Browser | None
_playwright: AsyncPlaywright | None
_loop: AbstractEventLoop
@property
def browser(self) -> Browser | None: ...
async def get_browser(self) -> Browser: ...

View File

@ -17,7 +17,7 @@ from enkanetwork.model.assets import CharacterAsset as EnkaCharacterAsset
from httpx import AsyncClient, HTTPError, HTTPStatusError, TransportError, URL from httpx import AsyncClient, HTTPError, HTTPStatusError, TransportError, URL
from typing_extensions import Self from typing_extensions import Self
from core.service import Service from core.base_service import BaseService
from metadata.genshin import AVATAR_DATA, HONEY_DATA, MATERIAL_DATA, NAMECARD_DATA, WEAPON_DATA from metadata.genshin import AVATAR_DATA, HONEY_DATA, MATERIAL_DATA, NAMECARD_DATA, WEAPON_DATA
from metadata.scripts.honey import update_honey_metadata from metadata.scripts.honey import update_honey_metadata
from metadata.scripts.metadatas import update_metadata_from_ambr, update_metadata_from_github from metadata.scripts.metadatas import update_metadata_from_ambr, update_metadata_from_github
@ -31,6 +31,8 @@ if TYPE_CHECKING:
from httpx import Response from httpx import Response
from multiprocessing.synchronize import RLock from multiprocessing.synchronize import RLock
__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets")
ICON_TYPE = Union[Callable[[bool], Awaitable[Optional[Path]]], Callable[..., Awaitable[Optional[Path]]]] ICON_TYPE = Union[Callable[[bool], Awaitable[Optional[Path]]], Callable[..., Awaitable[Optional[Path]]]]
NAME_MAP_TYPE = Dict[str, StrOrURL] NAME_MAP_TYPE = Dict[str, StrOrURL]
@ -127,7 +129,7 @@ class _AssetsService(ABC):
async def _download(self, url: StrOrURL, path: Path, retry: int = 5) -> Path | None: async def _download(self, url: StrOrURL, path: Path, retry: int = 5) -> Path | None:
"""从 url 下载图标至 path""" """从 url 下载图标至 path"""
logger.debug(f"正在从 {url} 下载图标至 {path}") logger.debug("正在从 %s 下载图标至 %s", url, path)
headers = {"user-agent": "TGPaimonBot/3.0"} if URL(url).host == "enka.network" else None headers = {"user-agent": "TGPaimonBot/3.0"} if URL(url).host == "enka.network" else None
for time in range(retry): for time in range(retry):
try: try:
@ -204,8 +206,8 @@ class _AssetsService(ABC):
"""魔法""" """魔法"""
if item in self.icon_types: if item in self.icon_types:
return partial(self._get_img, item=item) return partial(self._get_img, item=item)
else:
object.__getattribute__(self, item) object.__getattribute__(self, item)
return None
@abstractmethod @abstractmethod
@cached_property @cached_property
@ -498,7 +500,7 @@ class _NamecardAssets(_AssetsService):
} }
class AssetsService(Service): class AssetsService(BaseService.Dependence):
"""asset服务 """asset服务
用于储存和管理 asset : 用于储存和管理 asset :
@ -527,8 +529,10 @@ class AssetsService(Service):
): ):
setattr(self, attr, globals()[assets_type_name]()) setattr(self, attr, globals()[assets_type_name]())
async def start(self): # pylint: disable=R0201 async def initialize(self) -> None: # pylint: disable=R0201
"""启动 AssetsService 服务,刷新元数据"""
logger.info("正在刷新元数据") logger.info("正在刷新元数据")
# todo 这3个任务同时异步下载
await update_metadata_from_github(False) await update_metadata_from_github(False)
await update_metadata_from_ambr(False) await update_metadata_from_ambr(False)
await update_honey_metadata(False) await update_honey_metadata(False)

167
core/dependence/assets.pyi Normal file
View File

@ -0,0 +1,167 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Awaitable, Callable, ClassVar, TypeVar
from enkanetwork import Assets as EnkaAssets
from enkanetwork.model.assets import CharacterAsset as EnkaCharacterAsset
from httpx import AsyncClient
from typing_extensions import Self
from core.base_service import BaseService
from utils.typedefs import StrOrInt
__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets")
ICON_TYPE = Callable[[bool], Awaitable[Path | None]] | Callable[..., Awaitable[Path | None]]
DEFAULT_EnkaAssets: EnkaAssets
_GET_TYPE = partial | list[str] | int | str | ICON_TYPE | Path | AsyncClient | None | Self | dict[str, str]
class AssetsServiceError(Exception): ...
class AssetsCouldNotFound(AssetsServiceError):
message: str
target: str
def __init__(self, message: str, target: str): ...
class _AssetsService(ABC):
icon_types: ClassVar[list[str]]
id: int
type: str
icon: ICON_TYPE
"""图标"""
@abstractmethod
@property
def game_name(self) -> str:
"""游戏数据中的名称"""
@property
def honey_id(self) -> str:
"""当前资源在 Honey Impact 所对应的 ID"""
@property
def path(self) -> Path:
"""当前资源的文件夹"""
@property
def client(self) -> AsyncClient:
"""当前的 http client"""
def __init__(self, client: AsyncClient | None = None) -> None: ...
def __call__(self, target: int) -> Self:
"""用于生成与 target 对应的 assets"""
def __getattr__(self, item: str) -> _GET_TYPE:
"""魔法"""
async def get_link(self, item: str) -> str | None:
"""获取相应图标链接"""
@abstractmethod
@property
def game_name_map(self) -> dict[str, str]:
"""游戏中的图标名"""
@abstractmethod
@property
def honey_name_map(self) -> dict[str, str]:
"""来自honey的图标名"""
class _AvatarAssets(_AssetsService):
enka: EnkaCharacterAsset | None
side: ICON_TYPE
"""侧视图图标"""
card: ICON_TYPE
"""卡片图标"""
gacha: ICON_TYPE
"""抽卡立绘"""
gacha_card: ICON_TYPE
"""抽卡卡片"""
@property
def honey_name_map(self) -> dict[str, str]: ...
@property
def game_name_map(self) -> dict[str, str]: ...
@property
def enka(self) -> EnkaCharacterAsset | None: ...
def __init__(self, client: AsyncClient | None = None, enka: EnkaAssets | None = None) -> None: ...
def __call__(self, target: StrOrInt) -> Self: ...
def __getitem__(self, item: str) -> _GET_TYPE | EnkaCharacterAsset: ...
def game_name(self) -> str: ...
class _WeaponAssets(_AssetsService):
awaken: ICON_TYPE
"""突破后图标"""
gacha: ICON_TYPE
"""抽卡立绘"""
@property
def honey_name_map(self) -> dict[str, str]: ...
@property
def game_name_map(self) -> dict[str, str]: ...
def __call__(self, target: StrOrInt) -> Self: ...
def game_name(self) -> str: ...
class _MaterialAssets(_AssetsService):
@property
def honey_name_map(self) -> dict[str, str]: ...
@property
def game_name_map(self) -> dict[str, str]: ...
def __call__(self, target: StrOrInt) -> Self: ...
def game_name(self) -> str: ...
class _ArtifactAssets(_AssetsService):
flower: ICON_TYPE
"""生之花"""
plume: ICON_TYPE
"""死之羽"""
sands: ICON_TYPE
"""时之沙"""
goblet: ICON_TYPE
"""空之杯"""
circlet: ICON_TYPE
"""理之冠"""
@property
def honey_name_map(self) -> dict[str, str]: ...
@property
def game_name_map(self) -> dict[str, str]: ...
def game_name(self) -> str: ...
class _NamecardAssets(_AssetsService):
enka: EnkaCharacterAsset | None
navbar: ICON_TYPE
"""好友名片背景"""
profile: ICON_TYPE
"""个人资料名片背景"""
@property
def honey_name_map(self) -> dict[str, str]: ...
@property
def game_name_map(self) -> dict[str, str]: ...
def game_name(self) -> str: ...
class AssetsService(BaseService.Dependence):
avatar: _AvatarAssets
"""角色"""
weapon: _WeaponAssets
"""武器"""
material: _MaterialAssets
"""素材"""
artifact: _ArtifactAssets
"""圣遗物"""
namecard: _NamecardAssets
"""名片"""
AssetsServiceType = TypeVar("AssetsServiceType", bound=_AssetsService)

View File

@ -4,6 +4,8 @@ from urllib.parse import urlparse
import aiofiles import aiofiles
from core.base_service import BaseService
from core.config import config as bot_config
from utils.log import logger from utils.log import logger
try: try:
@ -13,13 +15,12 @@ try:
session.log.debug = lambda *args, **kwargs: None # 关闭日记 session.log.debug = lambda *args, **kwargs: None # 关闭日记
PYROGRAM_AVAILABLE = True PYROGRAM_AVAILABLE = True
except ImportError: except ImportError:
Client = None
session = None
PYROGRAM_AVAILABLE = False PYROGRAM_AVAILABLE = False
from core.bot import bot
from core.service import Service
class MTProto(BaseService.Dependence):
class MTProto(Service):
async def get_session(self): async def get_session(self):
async with aiofiles.open(self.session_path, mode="r") as f: async with aiofiles.open(self.session_path, mode="r") as f:
return await f.read() return await f.read()
@ -32,9 +33,9 @@ class MTProto(Service):
return os.path.exists(self.session_path) return os.path.exists(self.session_path)
def __init__(self): def __init__(self):
self.name = "PaimonBot" self.name = "paigram"
current_dir = os.getcwd() current_dir = os.getcwd()
self.session_path = os.path.join(current_dir, "paimon.session") self.session_path = os.path.join(current_dir, "paigram.session")
self.client: Optional[Client] = None self.client: Optional[Client] = None
self.proxy: Optional[dict] = None self.proxy: Optional[dict] = None
http_proxy = os.environ.get("HTTP_PROXY") http_proxy = os.environ.get("HTTP_PROXY")
@ -42,25 +43,25 @@ class MTProto(Service):
http_proxy_url = urlparse(http_proxy) http_proxy_url = urlparse(http_proxy)
self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port} self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port}
async def start(self): # pylint: disable=W0221 async def initialize(self): # pylint: disable=W0221
if not PYROGRAM_AVAILABLE: if not PYROGRAM_AVAILABLE:
logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None") logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None")
return return
if bot.config.mtproto.api_id is None: if bot_config.mtproto.api_id is None:
logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None") logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None")
return return
if bot.config.mtproto.api_hash is None: if bot_config.mtproto.api_hash is None:
logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None") logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None")
return return
self.client = Client( self.client = Client(
api_id=bot.config.mtproto.api_id, api_id=bot_config.mtproto.api_id,
api_hash=bot.config.mtproto.api_hash, api_hash=bot_config.mtproto.api_hash,
name=self.name, name=self.name,
bot_token=bot.config.bot_token, bot_token=bot_config.bot_token,
proxy=self.proxy, proxy=self.proxy,
) )
await self.client.start() await self.client.start()
async def stop(self): # pylint: disable=W0221 async def shutdown(self): # pylint: disable=W0221
if self.client is not None: if self.client is not None:
await self.client.stop(block=False) await self.client.stop(block=False)

View File

@ -0,0 +1,31 @@
from __future__ import annotations
from typing import TypedDict
from core.base_service import BaseService
try:
from pyrogram import Client
from pyrogram.session import session
PYROGRAM_AVAILABLE = True
except ImportError:
Client = None
session = None
PYROGRAM_AVAILABLE = False
__all__ = ("MTProto",)
class _ProxyType(TypedDict):
scheme: str
hostname: str | None
port: int | None
class MTProto(BaseService.Dependence):
name: str
session_path: str
client: Client | None
proxy: _ProxyType | None
async def get_session(self) -> str: ...
async def set_session(self, b: str) -> None: ...
def session_exists(self) -> bool: ...

50
core/dependence/mysql.py Normal file
View File

@ -0,0 +1,50 @@
import contextlib
from typing import Optional
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from typing_extensions import Self
from core.base_service import BaseService
from core.config import ApplicationConfig
from core.sqlmodel.session import AsyncSession
__all__ = ("MySQL",)
class MySQL(BaseService.Dependence):
@classmethod
def from_config(cls, config: ApplicationConfig) -> Self:
return cls(**config.mysql.dict())
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
username: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
):
self.database = database
self.password = password
self.username = username
self.port = port
self.host = host
self.url = URL.create(
"mysql+asyncmy",
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database,
)
self.engine = create_async_engine(self.url)
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
@contextlib.asynccontextmanager
async def session(self) -> AsyncSession:
yield self.Session()
async def shutdown(self):
self.Session.close_all()

View File

@ -1,4 +1,3 @@
import asyncio
from typing import Optional, Union from typing import Optional, Union
import fakeredis.aioredis import fakeredis.aioredis
@ -6,14 +5,16 @@ from redis import asyncio as aioredis
from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError
from typing_extensions import Self from typing_extensions import Self
from core.config import BotConfig from core.base_service import BaseService
from core.service import Service from core.config import ApplicationConfig
from utils.log import logger from utils.log import logger
__all__ = ["RedisDB"]
class RedisDB(Service):
class RedisDB(BaseService.Dependence):
@classmethod @classmethod
def from_config(cls, config: BotConfig) -> Self: def from_config(cls, config: ApplicationConfig) -> Self:
return cls(**config.redis.dict()) return cls(**config.redis.dict())
def __init__( def __init__(
@ -24,6 +25,7 @@ class RedisDB(Service):
self.key_prefix = "paimon_bot" self.key_prefix = "paimon_bot"
async def ping(self): async def ping(self):
# noinspection PyUnresolvedReferences
if await self.client.ping(): if await self.client.ping():
logger.info("连接 [red]Redis[/] 成功", extra={"markup": True}) logger.info("连接 [red]Redis[/] 成功", extra={"markup": True})
else: else:
@ -34,7 +36,7 @@ class RedisDB(Service):
self.client = fakeredis.aioredis.FakeRedis() self.client = fakeredis.aioredis.FakeRedis()
await self.ping() await self.ping()
async def start(self): # pylint: disable=W0221 async def initialize(self):
logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True}) logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True})
try: try:
await self.ping() await self.ping()
@ -45,5 +47,5 @@ class RedisDB(Service):
logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True}) logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
await self.start_fake_redis() await self.start_fake_redis()
async def stop(self): # pylint: disable=W0221 async def shutdown(self):
await self.client.close() await self.client.close()

View File

@ -1,16 +0,0 @@
from core.base.redisdb import RedisDB
from core.service import init_service
from .cache import GameCache
from .services import GameMaterialService, GameStrategyService
@init_service
def create_game_strategy_service(redis: RedisDB):
_cache = GameCache(redis, "game:strategy")
return GameStrategyService(_cache)
@init_service
def create_game_material_service(redis: RedisDB):
_cache = GameCache(redis, "game:material")
return GameMaterialService(_cache)

View File

@ -0,0 +1,59 @@
import asyncio
from typing import TypeVar, TYPE_CHECKING, Any, Optional
from telegram import Update
from telegram.ext import ApplicationHandlerStop, BaseHandler
from core.error import ServiceNotFoundError
from core.services.users.services import UserAdminService
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
from telegram.ext import Application as TelegramApplication
RT = TypeVar("RT")
UT = TypeVar("UT")
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
class AdminHandler(BaseHandler[Update, CCT]):
_lock = asyncio.Lock()
def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None:
self.handler = handler
self.application = application
self.user_service: Optional["UserAdminService"] = None
super().__init__(self.handler.callback)
def check_update(self, update: object) -> bool:
if not isinstance(update, Update):
return False
return self.handler.check_update(update)
async def _user_service(self) -> "UserAdminService":
async with self._lock:
if self.user_service is not None:
return self.user_service
user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None)
if user_service is None:
raise ServiceNotFoundError("UserAdminService")
self.user_service = user_service
return self.user_service
async def handle_update(
self,
update: "UT",
application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]",
check_result: Any,
context: "CCT",
) -> RT:
user_service = await self._user_service()
user = update.effective_user
if await user_service.is_admin(user.id):
return await self.handler.handle_update(update, application, check_result, context)
message = update.effective_message
logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id)
await message.reply_text("权限不足")
raise ApplicationHandlerStop

View File

@ -0,0 +1,62 @@
import asyncio
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type
from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop
from utils.log import logger
if TYPE_CHECKING:
from telegram.ext import Application
RT = TypeVar("RT")
UT = TypeVar("UT")
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
class OverlappingException(Exception):
pass
class OverlappingContext(AbstractAsyncContextManager):
_lock = asyncio.Lock()
def __init__(self, context: "CCT"):
self.context = context
async def __aenter__(self) -> None:
async with self._lock:
flag = self.context.user_data.get("overlapping", False)
if flag:
raise OverlappingException
self.context.user_data["overlapping"] = True
return None
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
async with self._lock:
del self.context.user_data["overlapping"]
return None
class CallbackQueryHandler(BaseCallbackQueryHandler):
async def handle_update(
self,
update: "UT",
application: "Application[Any, CCT, Any, Any, Any, Any]",
check_result: Any,
context: "CCT",
) -> RT:
self.collect_additional_context(context, update, application, check_result)
try:
async with OverlappingContext(context):
return await self.callback(update, context)
except OverlappingException as exc:
user = update.effective_user
logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id)
raise ApplicationHandlerStop from exc

View File

@ -0,0 +1,71 @@
import asyncio
from typing import TypeVar, Optional
from telegram import Update
from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler
from utils.log import logger
UT = TypeVar("UT")
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
class LimiterHandler(TypeHandler[UT, CCT]):
_lock = asyncio.Lock()
def __init__(
self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None
):
"""Limiter Handler 通过
`Leaky bucket algorithm <https://en.wikipedia.org/wiki/Leaky_bucket>`_
实现对用户的输入的精确控制
输入超过一定速率后代码会抛出
:class:`telegram.ext.ApplicationHandlerStop`
异常并在一段时间内防止用户执行任何其他操作
:param max_rate: 在抛出异常之前最多允许 频率/ 的速度
:param time_period: 在限制速率的时间段的持续时间
:param amount: 提供的容量
:param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount
"""
self.max_rate = max_rate
self.amount = amount
self._rate_per_sec = max_rate / time_period
self.limit_time = limit_time
super().__init__(Update, self.limiter_callback)
async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
if update.inline_query is not None:
return
loop = asyncio.get_running_loop()
async with self._lock:
time = loop.time()
user_data = context.user_data
if user_data is None:
return
user_limit_time = user_data.get("limit_time")
if user_limit_time is not None:
if time >= user_limit_time:
del user_data["limit_time"]
else:
raise ApplicationHandlerStop
last_task_time = user_data.get("last_task_time", 0)
if last_task_time:
task_level = user_data.get("task_level", 0)
elapsed = time - last_task_time
decrement = elapsed * self._rate_per_sec
task_level = max(task_level - decrement, 0)
user_data["task_level"] = task_level
if not task_level + self.amount <= self.max_rate:
if self.limit_time:
limit_time = self.limit_time
else:
limit_time = 1 / self._rate_per_sec * self.amount
user_data["limit_time"] = time + limit_time
user = update.effective_user
logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s", user.full_name, user.id, limit_time)
raise ApplicationHandlerStop
user_data["last_task_time"] = time
task_level = user_data.get("task_level", 0)
user_data["task_level"] = task_level + self.amount

286
core/manager.py Normal file
View File

@ -0,0 +1,286 @@
import asyncio
from importlib import import_module
from pathlib import Path
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
from arkowrapper import ArkoWrapper
from async_timeout import timeout
from typing_extensions import ParamSpec
from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
from core.config import config as bot_config
from utils.const import PLUGIN_DIR, PROJECT_ROOT
from utils.helpers import gen_pkg
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
from core.plugin import PluginType
from core.builtins.executor import Executor
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
R = TypeVar("R")
T = TypeVar("T")
P = ParamSpec("P")
def _load_module(path: Path) -> None:
for pkg in gen_pkg(path):
try:
logger.debug('正在导入 "%s"', pkg)
import_module(pkg)
except Exception as e:
logger.exception(
'在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
)
raise SystemExit from e
class Manager(Generic[T]):
"""生命周期控制基类"""
_executor: Optional["Executor"] = None
_lib: Dict[Type[T], T] = {}
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
return self._application
@property
def executor(self) -> "Executor":
"""执行器"""
if self._executor is None:
raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.")
return self._executor
def build_executor(self, name: str):
from core.builtins.executor import Executor
from core.builtins.dispatcher import BaseDispatcher
self._executor = Executor(name, dispatcher=BaseDispatcher)
self._executor.set_application(self.application)
class DependenceManager(Manager[DependenceType]):
"""基础依赖管理"""
_dependency: Dict[Type[DependenceType], DependenceType] = {}
@property
def dependency(self) -> List[DependenceType]:
return list(self._dependency.values())
@property
def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]:
return self._dependency
async def start_dependency(self) -> None:
_load_module(PROJECT_ROOT / "core/dependence")
for dependence in filter(lambda x: x.is_dependence, get_all_services()):
dependence: Type[DependenceType]
instance: DependenceType
try:
if hasattr(dependence, "from_config"): # 如果有 from_config 方法
instance = dependence.from_config(bot_config) # 用 from_config 实例化服务
else:
instance = await self.executor(dependence)
await instance.initialize()
logger.success('基础服务 "%s" 启动成功', dependence.__name__)
self._lib[dependence] = instance
self._dependency[dependence] = instance
except Exception as e:
logger.exception('基础服务 "%s" 初始化失败BOT 将自动关闭', dependence.__name__)
raise SystemExit from e
async def stop_dependency(self) -> None:
async def task(d):
try:
async with timeout(5):
await d.shutdown()
logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__)
except asyncio.TimeoutError:
logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__)
except Exception as e:
logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e)
tasks = []
for dependence in self._dependency.values():
tasks.append(asyncio.create_task(task(dependence)))
await asyncio.gather(*tasks)
class ComponentManager(Manager[ComponentType]):
"""组件管理"""
_components: Dict[Type[ComponentType], ComponentType] = {}
@property
def components(self) -> List[ComponentType]:
return list(self._components.values())
@property
def components_map(self) -> Dict[Type[ComponentType], ComponentType]:
return self._components
async def init_components(self):
for path in filter(
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
):
_load_module(path)
components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component)
retry_times = 0
max_retry_times = len(components)
while components:
start_len = len(components)
for component in list(components):
component: Type[ComponentType]
instance: ComponentType
try:
instance = await self.executor(component)
self._lib[component] = instance
self._components[component] = instance
components = components.remove(component)
except Exception as e: # pylint: disable=W0703
logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True})
end_len = len(list(components))
if start_len == end_len:
retry_times += 1
if retry_times == max_retry_times and components:
for component in components:
logger.error('组件 "%s" 初始化失败', component.__name__)
raise SystemExit
class ServiceManager(Manager[BaseServiceType]):
"""服务控制类"""
_services: Dict[Type[BaseServiceType], BaseServiceType] = {}
@property
def services(self) -> List[BaseServiceType]:
return list(self._services.values())
@property
def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]:
return self._services
async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType:
instance: BaseServiceType
try:
if hasattr(target, "from_config"): # 如果有 from_config 方法
instance = target.from_config(bot_config) # 用 from_config 实例化服务
else:
instance = await self.executor(target)
await instance.initialize()
logger.success('服务 "%s" 启动成功', target.__name__)
return instance
except Exception as e: # pylint: disable=W0703
logger.exception('服务 "%s" 初始化失败BOT 将自动关闭', target.__name__)
raise SystemExit from e
async def start_services(self) -> None:
for path in filter(
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
):
_load_module(path)
for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类
instance = await self._initialize_service(service)
self._lib[service] = instance
self._services[service] = instance
async def stop_services(self) -> None:
"""关闭服务"""
if not self._services:
return
async def task(s):
try:
async with timeout(5):
await s.shutdown()
logger.success('服务 "%s" 关闭成功', s.__class__.__name__)
except asyncio.TimeoutError:
logger.warning('服务 "%s" 关闭超时', s.__class__.__name__)
except Exception as e:
logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e)
logger.info("正在关闭服务")
tasks = []
for service in self._services.values():
tasks.append(asyncio.create_task(task(service)))
await asyncio.gather(*tasks)
class PluginManager(Manager["PluginType"]):
"""插件管理"""
_plugins: Dict[Type["PluginType"], "PluginType"] = {}
@property
def plugins(self) -> List["PluginType"]:
"""所有已经加载的插件"""
return list(self._plugins.values())
@property
def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]:
return self._plugins
async def install_plugins(self) -> None:
"""安装所有插件"""
from core.plugin import get_all_plugins
for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()):
_load_module(path)
for plugin in get_all_plugins():
plugin: Type["PluginType"]
try:
instance: "PluginType" = await self.executor(plugin)
except Exception as e: # pylint: disable=W0703
logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
continue
self._plugins[plugin] = instance
if self._application is not None:
instance.set_application(self._application)
await asyncio.create_task(self.plugin_install_task(plugin, instance))
@staticmethod
async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"):
try:
await instance.install()
logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}")
except Exception as e: # pylint: disable=W0703
logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
async def uninstall_plugins(self) -> None:
for plugin in self._plugins.values():
try:
await plugin.uninstall()
except Exception as e: # pylint: disable=W0703
logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager):
"""BOT 除自身外的生命周期管理类"""

106
core/override/telegram.py Normal file
View File

@ -0,0 +1,106 @@
"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化"""
from typing import Any, AsyncIterable, Optional
import httpcore
from httpx import (
AsyncByteStream,
AsyncHTTPTransport as DefaultAsyncHTTPTransport,
Limits,
Response as DefaultResponse,
Timeout,
)
from telegram.request import HTTPXRequest as DefaultHTTPXRequest
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
__all__ = ("HTTPXRequest",)
class Response(DefaultResponse):
def json(self, **kwargs: Any) -> Any:
# noinspection PyProtectedMember
from httpx._utils import guess_json_utf
if self.charset_encoding is None and self.content and len(self.content) > 3:
encoding = guess_json_utf(self.content)
if encoding is not None:
return jsonlib.loads(self.content.decode(encoding), **kwargs)
return jsonlib.loads(self.text, **kwargs)
# noinspection PyProtectedMember
class AsyncHTTPTransport(DefaultAsyncHTTPTransport):
async def handle_async_request(self, request) -> Response:
from httpx._transports.default import (
map_httpcore_exceptions,
AsyncResponseStream,
)
if not isinstance(request.stream, AsyncByteStream):
raise AssertionError
req = httpcore.Request(
method=request.method,
url=httpcore.URL(
scheme=request.url.raw_scheme,
host=request.url.raw_host,
port=request.url.port,
target=request.url.raw_path,
),
headers=request.headers.raw,
content=request.stream,
extensions=request.extensions,
)
with map_httpcore_exceptions():
resp = await self._pool.handle_async_request(req)
if not isinstance(resp.stream, AsyncIterable):
raise AssertionError
return Response(
status_code=resp.status,
headers=resp.headers,
stream=AsyncResponseStream(resp.stream),
extensions=resp.extensions,
)
class HTTPXRequest(DefaultHTTPXRequest):
def __init__( # pylint: disable=W0231
self,
connection_pool_size: int = 1,
proxy_url: str = None,
read_timeout: Optional[float] = 5.0,
write_timeout: Optional[float] = 5.0,
connect_timeout: Optional[float] = 5.0,
pool_timeout: Optional[float] = 1.0,
):
timeout = Timeout(
connect=connect_timeout,
read=read_timeout,
write=write_timeout,
pool=pool_timeout,
)
limits = Limits(
max_connections=connection_pool_size,
max_keepalive_connections=connection_pool_size,
)
self._client_kwargs = dict(
timeout=timeout,
proxies=proxy_url,
limits=limits,
transport=AsyncHTTPTransport(limits=limits),
)
try:
self._client = self._build_client()
except ImportError as exc:
if "httpx[socks]" not in str(exc):
raise exc
raise RuntimeError(
"To use Socks5 proxies, PTB must be installed via `pip install python-telegram-bot[socks]`."
) from exc

View File

@ -1,483 +0,0 @@
import copy
import datetime
import re
from importlib import import_module
from re import Pattern
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
# noinspection PyProtectedMember
from telegram._utils.defaultvalue import DEFAULT_TRUE
# noinspection PyProtectedMember
from telegram._utils.types import DVInput, JSONDict
from telegram.ext import BaseHandler, ConversationHandler, Job
# noinspection PyProtectedMember
from telegram.ext._utils.types import JobCallback
from telegram.ext.filters import BaseFilter
from typing_extensions import ParamSpec
__all__ = ["Plugin", "handler", "conversation", "job", "error_handler"]
P = ParamSpec("P")
T = TypeVar("T")
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
_Module = import_module("telegram.ext")
_NORMAL_HANDLER_ATTR_NAME = "_handler_data"
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_data"
_JOB_ATTR_NAME = "_job_data"
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
class _Plugin:
def _make_handler(self, datas: Union[List[Dict], Dict]) -> List[HandlerType]:
result = []
if isinstance(datas, list):
for data in filter(lambda x: x, datas):
func = getattr(self, data.pop("func"))
result.append(data.pop("type")(callback=func, **data.pop("kwargs")))
else:
func = getattr(self, datas.pop("func"))
result.append(datas.pop("type")(callback=func, **datas.pop("kwargs")))
return result
@property
def handlers(self) -> List[HandlerType]:
result = []
for attr in dir(self):
# noinspection PyUnboundLocalVariable
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
):
for data in datas:
if data["type"] not in ["error", "new_chat_member"]:
result.extend(self._make_handler(data))
return result
def _new_chat_members_handler_funcs(self) -> List[Tuple[int, Callable]]:
result = []
for attr in dir(self):
# noinspection PyUnboundLocalVariable
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
):
for data in datas:
if data and data["type"] == "new_chat_member":
result.append((data["priority"], func))
return result
@property
def error_handlers(self) -> Dict[Callable, bool]:
result = {}
for attr in dir(self):
# noinspection PyUnboundLocalVariable
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
):
for data in datas:
if data and data["type"] == "error":
result.update({func: data["block"]})
return result
@property
def jobs(self) -> List[Job]:
from core.bot import bot
result = []
for attr in dir(self):
# noinspection PyUnboundLocalVariable
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _JOB_ATTR_NAME, None))
):
for data in datas:
_job = getattr(bot.job_queue, data.pop("type"))(
callback=func, **data.pop("kwargs"), **{key: data.pop(key) for key in list(data.keys())}
)
result.append(_job)
return result
class _Conversation(_Plugin):
_conversation_kwargs: Dict
def __init_subclass__(cls, **kwargs):
cls._conversation_kwargs = kwargs
super(_Conversation, cls).__init_subclass__()
return cls
@property
def handlers(self) -> List[HandlerType]:
result: List[HandlerType] = []
entry_points: List[HandlerType] = []
states: Dict[Any, List[HandlerType]] = {}
fallbacks: List[HandlerType] = []
for attr in dir(self):
# noinspection PyUnboundLocalVariable
if (
not (attr.startswith("_") or attr == "handlers")
and isinstance(func := getattr(self, attr), Callable)
and (handler_datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
):
conversation_data = getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None)
if attr == "cancel":
handler_datas = copy.deepcopy(handler_datas)
conversation_data = copy.deepcopy(conversation_data)
_handlers = self._make_handler(handler_datas)
if conversation_data:
if (_type := conversation_data.pop("type")) == "entry":
entry_points.extend(_handlers)
elif _type == "state":
if (key := conversation_data.pop("state")) in states:
states[key].extend(_handlers)
else:
states[key] = _handlers
elif _type == "fallback":
fallbacks.extend(_handlers)
else:
result.extend(_handlers)
if entry_points or states or fallbacks:
result.append(
ConversationHandler(
entry_points, states, fallbacks, **self.__class__._conversation_kwargs # pylint: disable=W0212
)
)
return result
class Plugin(_Plugin):
Conversation = _Conversation
class _Handler:
def __init__(self, **kwargs):
self.kwargs = kwargs
@property
def _type(self) -> Type[BaseHandler]:
return getattr(_Module, f"{self.__class__.__name__.strip('_')}Handler")
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
data = {"type": self._type, "func": func.__name__, "kwargs": self.kwargs}
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
handler_datas.append(data)
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
else:
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
return func
class _CallbackQuery(_Handler):
def __init__(
self,
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
block: DVInput[bool] = DEFAULT_TRUE,
):
super(_CallbackQuery, self).__init__(pattern=pattern, block=block)
class _ChatJoinRequest(_Handler):
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
super(_ChatJoinRequest, self).__init__(block=block)
class _ChatMember(_Handler):
def __init__(self, chat_member_types: int = -1, block: DVInput[bool] = DEFAULT_TRUE):
super().__init__(chat_member_types=chat_member_types, block=block)
class _ChosenInlineResult(_Handler):
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE, pattern: Union[str, Pattern] = None):
super().__init__(block=block, pattern=pattern)
class _Command(_Handler):
def __init__(self, command: str, filters: "BaseFilter" = None, block: DVInput[bool] = DEFAULT_TRUE):
super(_Command, self).__init__(command=command, filters=filters, block=block)
class _InlineQuery(_Handler):
def __init__(
self, pattern: Union[str, Pattern] = None, block: DVInput[bool] = DEFAULT_TRUE, chat_types: List[str] = None
):
super().__init__(pattern=pattern, block=block, chat_types=chat_types)
class _MessageNewChatMembers(_Handler):
def __init__(self, func: Callable[P, T] = None, *, priority: int = 5):
super().__init__()
self.func = func
self.priority = priority
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
self.func = self.func or func
data = {"type": "new_chat_member", "priority": self.priority}
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
handler_datas.append(data)
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
else:
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
return func
class _Message(_Handler):
def __init__(
self,
filters: "BaseFilter",
block: DVInput[bool] = DEFAULT_TRUE,
):
super(_Message, self).__init__(filters=filters, block=block)
new_chat_members = _MessageNewChatMembers
class _PollAnswer(_Handler):
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
super(_PollAnswer, self).__init__(block=block)
class _Poll(_Handler):
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
super(_Poll, self).__init__(block=block)
class _PreCheckoutQuery(_Handler):
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
super(_PreCheckoutQuery, self).__init__(block=block)
class _Prefix(_Handler):
def __init__(
self,
prefix: str,
command: str,
filters: BaseFilter = None,
block: DVInput[bool] = DEFAULT_TRUE,
):
super(_Prefix, self).__init__(prefix=prefix, command=command, filters=filters, block=block)
class _ShippingQuery(_Handler):
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
super(_ShippingQuery, self).__init__(block=block)
class _StringCommand(_Handler):
def __init__(self, command: str):
super(_StringCommand, self).__init__(command=command)
class _StringRegex(_Handler):
def __init__(self, pattern: Union[str, Pattern], block: DVInput[bool] = DEFAULT_TRUE):
super(_StringRegex, self).__init__(pattern=pattern, block=block)
class _Type(_Handler):
# noinspection PyShadowingBuiltins
def __init__(
self, type: Type, strict: bool = False, block: DVInput[bool] = DEFAULT_TRUE # pylint: disable=redefined-builtin
):
super(_Type, self).__init__(type=type, strict=strict, block=block)
# noinspection PyPep8Naming
class handler(_Handler):
def __init__(self, handler_type: Callable[P, HandlerType], **kwargs: P.kwargs):
self._type_ = handler_type
super(handler, self).__init__(**kwargs)
@property
def _type(self) -> Type[BaseHandler]:
# noinspection PyTypeChecker
return self._type_
callback_query = _CallbackQuery
chat_join_request = _ChatJoinRequest
chat_member = _ChatMember
chosen_inline_result = _ChosenInlineResult
command = _Command
inline_query = _InlineQuery
message = _Message
poll_answer = _PollAnswer
pool = _Poll
pre_checkout_query = _PreCheckoutQuery
prefix = _Prefix
shipping_query = _ShippingQuery
string_command = _StringCommand
string_regex = _StringRegex
type = _Type
# noinspection PyPep8Naming
class error_handler:
def __init__(self, func: Callable[P, T] = None, *, block: bool = DEFAULT_TRUE):
self._func = func
self._block = block
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
self._func = func or self._func
data = {"type": "error", "block": self._block}
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
handler_datas.append(data)
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
else:
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
return func
def _entry(func: Callable[P, T]) -> Callable[P, T]:
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "entry"})
return func
class _State:
def __init__(self, state: Any):
self.state = state
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "state", "state": self.state})
return func
def _fallback(func: Callable[P, T]) -> Callable[P, T]:
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "fallback"})
return func
# noinspection PyPep8Naming
class conversation(_Handler):
entry_point = _entry
state = _State
fallback = _fallback
class _Job:
kwargs: Dict = {}
def __init__(
self,
name: str = None,
data: object = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
**kwargs,
):
self.name = name
self.data = data
self.chat_id = chat_id
self.user_id = user_id
self.job_kwargs = {} if job_kwargs is None else job_kwargs
self.kwargs = kwargs
def __call__(self, func: JobCallback) -> JobCallback:
data = {
"name": self.name,
"data": self.data,
"chat_id": self.chat_id,
"user_id": self.user_id,
"job_kwargs": self.job_kwargs,
"kwargs": self.kwargs,
"type": re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
}
if hasattr(func, _JOB_ATTR_NAME):
job_datas = getattr(func, _JOB_ATTR_NAME)
job_datas.append(data)
setattr(func, _JOB_ATTR_NAME, job_datas)
else:
setattr(func, _JOB_ATTR_NAME, [data])
return func
class _RunOnce(_Job):
def __init__(
self,
when: TimeType,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, when=when)
class _RunRepeating(_Job):
def __init__(
self,
interval: Union[float, datetime.timedelta],
first: TimeType = None,
last: TimeType = None,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, interval=interval, first=first, last=last)
class _RunMonthly(_Job):
def __init__(
self,
when: datetime.time,
day: int,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, when=when, day=day)
class _RunDaily(_Job):
def __init__(
self,
time: datetime.time,
days: Tuple[int, ...] = tuple(range(7)),
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, time=time, days=days)
class _RunCustom(_Job):
def __init__(
self,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs)
# noinspection PyPep8Naming
class job:
run_once = _RunOnce
run_repeating = _RunRepeating
run_monthly = _RunMonthly
run_daily = _RunDaily
run_custom = _RunCustom

16
core/plugin/__init__.py Normal file
View File

@ -0,0 +1,16 @@
"""插件"""
from core.plugin._handler import conversation, error_handler, handler
from core.plugin._job import TimeType, job
from core.plugin._plugin import Plugin, PluginType, get_all_plugins
__all__ = (
"Plugin",
"PluginType",
"get_all_plugins",
"handler",
"error_handler",
"conversation",
"job",
"TimeType",
)

175
core/plugin/_funcs.py Normal file
View File

@ -0,0 +1,175 @@
from pathlib import Path
from typing import List, Optional, Union, TYPE_CHECKING
import aiofiles
import httpx
from httpx import UnsupportedProtocol
from telegram import Chat, Message, ReplyKeyboardRemove, Update
from telegram.error import BadRequest, Forbidden
from telegram.ext import CallbackContext, ConversationHandler, Job
from core.dependence.redisdb import RedisDB
from core.plugin._handler import conversation, handler
from utils.const import CACHE_DIR, REQUEST_HEADERS
from utils.error import UrlResourcesNotFoundError
from utils.helpers import sha1
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
try:
import ujson as json
except ImportError:
import json
__all__ = (
"PluginFuncs",
"ConversationFuncs",
)
class PluginFuncs:
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError("No application was set for this PluginManager.")
return self._application
async def _delete_message(self, context: CallbackContext) -> None:
job = context.job
message_id = job.data
chat_info = f"chat_id[{job.chat_id}]"
try:
chat = await self.get_chat(job.chat_id)
full_name = chat.full_name
if full_name:
chat_info = f"{full_name}[{chat.id}]"
else:
chat_info = f"{chat.title}[{chat.id}]"
except (BadRequest, Forbidden) as exc:
logger.warning("获取 chat info 失败 %s", exc.message)
except Exception as exc:
logger.warning("获取 chat info 消息失败 %s", str(exc))
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
try:
# noinspection PyTypeChecker
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
except BadRequest as exc:
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, ttl: int = 86400) -> Chat:
application = self.application
redis_db: RedisDB = redis_db or self.application.managers.services_map.get(RedisDB, None)
if not redis_db:
return await application.bot.get_chat(chat_id)
qname = f"bot:chat:{chat_id}"
data = await redis_db.client.get(qname)
if data:
json_data = json.loads(data)
return Chat.de_json(json_data, application.telegram.bot)
chat_info = await application.telegram.bot.get_chat(chat_id)
await redis_db.client.set(qname, chat_info.to_json())
await redis_db.client.expire(qname, ttl)
return chat_info
def add_delete_message_job(
self,
message: Optional[Union[int, Message]] = None,
*,
delay: int = 60,
name: Optional[str] = None,
chat: Optional[Union[int, Chat]] = None,
context: Optional[CallbackContext] = None,
) -> Job:
"""延迟删除消息"""
if isinstance(message, Message):
if chat is None:
chat = message.chat_id
message = message.id
chat = chat.id if isinstance(chat, Chat) else chat
job_queue = self.application.job_queue or context.job_queue
if job_queue is None:
raise RuntimeError
return job_queue.run_once(
callback=self._delete_message,
when=delay,
data=message,
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
chat_id=chat,
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
)
@staticmethod
async def download_resource(url: str, return_path: bool = False) -> str:
url_sha1 = sha1(url) # url 的 hash 值
pathed_url = Path(url)
file_name = url_sha1 + pathed_url.suffix
file_path = CACHE_DIR.joinpath(file_name)
if not file_path.exists(): # 若文件不存在,则下载
async with httpx.AsyncClient(headers=REQUEST_HEADERS) as client:
try:
response = await client.get(url)
except UnsupportedProtocol:
logger.error("链接不支持 url[%s]", url)
return ""
if response.is_error:
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
raise UrlResourcesNotFoundError(url)
if response.status_code != 200:
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
raise UrlResourcesNotFoundError(url)
async with aiofiles.open(file_path, mode="wb") as f:
await f.write(response.content)
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
return file_path if return_path else Path(file_path).as_uri()
@staticmethod
def get_args(context: Optional[CallbackContext] = None) -> List[str]:
args = context.args
match = context.match
if args is None:
if match is not None and (command := match.groups()[0]):
temp = []
command_parts = command.split(" ")
for command_part in command_parts:
if command_part:
temp.append(command_part)
return temp
return []
if len(args) >= 1:
return args
return []
class ConversationFuncs:
@conversation.fallback
@handler.command(command="cancel", block=True)
async def cancel(self, update: Update, _) -> int:
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
return ConversationHandler.END

380
core/plugin/_handler.py Normal file
View File

@ -0,0 +1,380 @@
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from importlib import import_module
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Pattern,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel
# noinspection PyProtectedMember
from telegram._utils.defaultvalue import DEFAULT_TRUE
# noinspection PyProtectedMember
from telegram._utils.types import DVInput
from telegram.ext import BaseHandler
from telegram.ext.filters import BaseFilter
from typing_extensions import ParamSpec
from core.handler.callbackqueryhandler import CallbackQueryHandler
from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS
if TYPE_CHECKING:
from core.builtins.dispatcher import AbstractDispatcher
__all__ = (
"handler",
"conversation",
"ConversationDataType",
"ConversationData",
"HandlerData",
"ErrorHandlerData",
"error_handler",
)
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")
UT = TypeVar("UT")
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
HandlerCls = Type[HandlerType]
Module = import_module("telegram.ext")
HANDLER_DATA_ATTR_NAME = "_handler_datas"
"""用于储存生成 handler 时所需要的参数(例如 block的属性名"""
ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block的属性名"""
WRAPPER_ASSIGNMENTS = list(
set(
_WRAPPER_ASSIGNMENTS
+ [
HANDLER_DATA_ATTR_NAME,
ERROR_HANDLER_ATTR_NAME,
CONVERSATION_HANDLER_ATTR_NAME,
]
)
)
@dataclass(init=True)
class HandlerData:
type: Type[HandlerType]
admin: bool
kwargs: Dict[str, Any]
dispatcher: Optional[Type["AbstractDispatcher"]] = None
class _Handler:
_type: Type["HandlerType"]
kwargs: Dict[str, Any] = {}
def __init_subclass__(cls, **kwargs) -> None:
"""用于获取 python-telegram-bot 中对应的 handler class"""
handler_name = f"{cls.__name__.strip('_')}Handler"
if handler_name == "CallbackQueryHandler":
cls._type = CallbackQueryHandler
return
cls._type = getattr(Module, handler_name, None)
def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None:
self.dispatcher = dispatcher
self.admin = admin
self.kwargs = kwargs
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
"""decorator实现从 func 生成 Handler"""
handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, [])
handler_datas.append(
HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher)
)
setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas)
return func
class _CallbackQuery(_Handler):
def __init__(
self,
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
class _ChatJoinRequest(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher)
class _ChatMember(_Handler):
def __init__(
self,
chat_member_types: int = -1,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher)
class _ChosenInlineResult(_Handler):
def __init__(
self,
block: DVInput[bool] = DEFAULT_TRUE,
*,
pattern: Union[str, Pattern] = None,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(block=block, pattern=pattern, dispatcher=dispatcher)
class _Command(_Handler):
def __init__(
self,
command: Union[str, List[str]],
filters: "BaseFilter" = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_Command, self).__init__(
command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher
)
class _InlineQuery(_Handler):
def __init__(
self,
pattern: Union[str, Pattern] = None,
chat_types: List[str] = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher)
class _Message(_Handler):
def __init__(
self,
filters: BaseFilter,
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
) -> None:
super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher)
class _PollAnswer(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher)
class _Poll(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_Poll, self).__init__(block=block, dispatcher=dispatcher)
class _PreCheckoutQuery(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher)
class _Prefix(_Handler):
def __init__(
self,
prefix: str,
command: str,
filters: BaseFilter = None,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_Prefix, self).__init__(
prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher
)
class _ShippingQuery(_Handler):
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher)
class _StringCommand(_Handler):
def __init__(
self,
command: str,
*,
admin: bool = False,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher)
class _StringRegex(_Handler):
def __init__(
self,
pattern: Union[str, Pattern],
*,
block: DVInput[bool] = DEFAULT_TRUE,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
class _Type(_Handler):
# noinspection PyShadowingBuiltins
def __init__(
self,
type: Type[UT], # pylint: disable=W0622
strict: bool = False,
*,
block: DVInput[bool] = DEFAULT_TRUE,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
): # pylint: disable=redefined-builtin
super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher)
# noinspection PyPep8Naming
class handler(_Handler):
callback_query = _CallbackQuery
chat_join_request = _ChatJoinRequest
chat_member = _ChatMember
chosen_inline_result = _ChosenInlineResult
command = _Command
inline_query = _InlineQuery
message = _Message
poll_answer = _PollAnswer
pool = _Poll
pre_checkout_query = _PreCheckoutQuery
prefix = _Prefix
shipping_query = _ShippingQuery
string_command = _StringCommand
string_regex = _StringRegex
type = _Type
def __init__(
self,
handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]],
*,
admin: bool = False,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
**kwargs: P.kwargs,
) -> None:
self._type = handler_type
super().__init__(admin=admin, dispatcher=dispatcher, **kwargs)
class ConversationDataType(Enum):
"""conversation handler 的类型"""
Entry = "entry"
State = "state"
Fallback = "fallback"
class ConversationData(BaseModel):
"""用于储存 conversation handler 的数据"""
type: ConversationDataType
state: Optional[Any] = None
class _ConversationType:
_type: ClassVar[ConversationDataType]
def __init_subclass__(cls, **kwargs) -> None:
cls._type = ConversationDataType(cls.__name__.lstrip("_").lower())
def _entry(func: Callable[P, R]) -> Callable[P, R]:
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry))
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapped
class _State(_ConversationType):
def __init__(self, state: Any) -> None:
self.state = state
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state))
return func
def _fallback(func: Callable[P, R]) -> Callable[P, R]:
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback))
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)
return wrapped
# noinspection PyPep8Naming
class conversation(_Handler):
entry_point = _entry
state = _State
fallback = _fallback
@dataclass(init=True)
class ErrorHandlerData:
block: bool
func: Optional[Callable] = None
# noinspection PyPep8Naming
class error_handler:
_func: Callable[P, R]
def __init__(
self,
*,
block: bool = DEFAULT_TRUE,
):
self._block = block
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
self._func = func
wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self)
handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, [])
handler_datas.append(ErrorHandlerData(block=self._block))
setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas)
return self._func

173
core/plugin/_job.py Normal file
View File

@ -0,0 +1,173 @@
"""插件"""
import datetime
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
# noinspection PyProtectedMember
from telegram._utils.types import JSONDict
# noinspection PyProtectedMember
from telegram.ext._utils.types import JobCallback
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from core.builtins.dispatcher import AbstractDispatcher
__all__ = ["TimeType", "job", "JobData"]
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
_JOB_ATTR_NAME = "_job_data"
@dataclass(init=True)
class JobData:
name: str
data: Any
chat_id: int
user_id: int
type: str
job_kwargs: JSONDict = field(default_factory=dict)
kwargs: JSONDict = field(default_factory=dict)
dispatcher: Optional[Type["AbstractDispatcher"]] = None
class _Job:
kwargs: Dict = {}
def __init__(
self,
name: str = None,
data: object = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
**kwargs,
):
self.name = name
self.data = data
self.chat_id = chat_id
self.user_id = user_id
self.job_kwargs = {} if job_kwargs is None else job_kwargs
self.kwargs = kwargs
if dispatcher is None:
from core.builtins.dispatcher import JobDispatcher
dispatcher = JobDispatcher
self.dispatcher = dispatcher
def __call__(self, func: JobCallback) -> JobCallback:
data = JobData(
name=self.name,
data=self.data,
chat_id=self.chat_id,
user_id=self.user_id,
job_kwargs=self.job_kwargs,
kwargs=self.kwargs,
type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
dispatcher=self.dispatcher,
)
if hasattr(func, _JOB_ATTR_NAME):
job_datas = getattr(func, _JOB_ATTR_NAME)
job_datas.append(data)
setattr(func, _JOB_ATTR_NAME, job_datas)
else:
setattr(func, _JOB_ATTR_NAME, [data])
return func
class _RunOnce(_Job):
def __init__(
self,
when: TimeType,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when)
class _RunRepeating(_Job):
def __init__(
self,
interval: Union[float, datetime.timedelta],
first: TimeType = None,
last: TimeType = None,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(
name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last
)
class _RunMonthly(_Job):
def __init__(
self,
when: datetime.time,
day: int,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day)
class _RunDaily(_Job):
def __init__(
self,
time: datetime.time,
days: Tuple[int, ...] = tuple(range(7)),
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days)
class _RunCustom(_Job):
def __init__(
self,
data: object = None,
name: str = None,
chat_id: int = None,
user_id: int = None,
job_kwargs: JSONDict = None,
*,
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
):
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher)
# noinspection PyPep8Naming
class job:
run_once = _RunOnce
run_repeating = _RunRepeating
run_monthly = _RunMonthly
run_daily = _RunDaily
run_custom = _RunCustom

303
core/plugin/_plugin.py Normal file
View File

@ -0,0 +1,303 @@
"""插件"""
import asyncio
from abc import ABC
from dataclasses import asdict
from datetime import timedelta
from functools import partial, wraps
from itertools import chain
from multiprocessing import RLock as Lock
from types import MethodType
from typing import (
Any,
ClassVar,
Dict,
Iterable,
List,
Optional,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
from typing_extensions import ParamSpec
from core.handler.adminhandler import AdminHandler
from core.plugin._funcs import ConversationFuncs, PluginFuncs
from core.plugin._handler import ConversationDataType
from utils.const import WRAPPER_ASSIGNMENTS
from utils.helpers import isabstract
from utils.log import logger
if TYPE_CHECKING:
from core.application import Application
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
from core.plugin._job import JobData
from multiprocessing.synchronize import RLock as LockType
__all__ = ("Plugin", "PluginType", "get_all_plugins")
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R")
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
"""用于储存生成 handler 时所需要的参数(例如 block的属性名"""
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block的属性名"""
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
_JOB_ATTR_NAME = "_job_data"
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
class _Plugin(PluginFuncs):
"""插件"""
_lock: ClassVar["LockType"] = Lock()
_asyncio_lock: ClassVar["LockType"] = asyncio.Lock()
_installed: bool = False
_handlers: Optional[List[HandlerType]] = None
_error_handlers: Optional[List["ErrorHandlerData"]] = None
_jobs: Optional[List[Job]] = None
_application: "Optional[Application]" = None
def set_application(self, application: "Application") -> None:
self._application = application
@property
def application(self) -> "Application":
if self._application is None:
raise RuntimeError("No application was set for this Plugin.")
return self._application
@property
def handlers(self) -> List[HandlerType]:
"""该插件的所有 handler"""
with self._lock:
if self._handlers is None:
self._handlers = []
for attr in dir(self):
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
):
for data in datas:
data: "HandlerData"
if data.admin:
self._handlers.append(
AdminHandler(
handler=data.type(
callback=func,
**data.kwargs,
),
application=self.application,
)
)
else:
self._handlers.append(
data.type(
callback=func,
**data.kwargs,
)
)
return self._handlers
@property
def error_handlers(self) -> List["ErrorHandlerData"]:
with self._lock:
if self._error_handlers is None:
self._error_handlers = []
for attr in dir(self):
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, []))
):
for data in datas:
data: "ErrorHandlerData"
data.func = func
self._error_handlers.append(data)
return self._error_handlers
def _install_jobs(self) -> None:
if self._jobs is None:
self._jobs = []
for attr in dir(self):
# noinspection PyUnboundLocalVariable
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and isinstance(func := getattr(self, attr), MethodType)
and (datas := getattr(func, _JOB_ATTR_NAME, []))
):
for data in datas:
data: "JobData"
self._jobs.append(
getattr(self.application.telegram.job_queue, data.type)(
callback=func,
**data.kwargs,
**{
key: value
for key, value in asdict(data).items()
if key not in ["type", "kwargs", "dispatcher"]
},
)
)
@property
def jobs(self) -> List[Job]:
with self._lock:
if self._jobs is None:
self._jobs = []
self._install_jobs()
return self._jobs
async def initialize(self) -> None:
"""初始化插件"""
async def shutdown(self) -> None:
"""销毁插件"""
async def install(self) -> None:
"""安装"""
group = id(self)
if not self._installed:
await self.initialize()
# initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题
async with self._asyncio_lock:
self._install_jobs()
for h in self.handlers:
if not isinstance(h, TypeHandler):
self.application.telegram.add_handler(h, group)
else:
self.application.telegram.add_handler(h, -1)
for h in self.error_handlers:
self.application.telegram.add_error_handler(h.func, h.block)
self._installed = True
async def uninstall(self) -> None:
"""卸载"""
group = id(self)
with self._lock:
if self._installed:
if group in self.application.telegram.handlers:
del self.application.telegram.handlers[id(self)]
for h in self.handlers:
if isinstance(h, TypeHandler):
self.application.telegram.remove_handler(h, -1)
for h in self.error_handlers:
self.application.telegram.remove_error_handler(h.func)
for j in self.application.telegram.job_queue.jobs():
j.schedule_removal()
await self.shutdown()
self._installed = False
async def reload(self) -> None:
await self.uninstall()
await self.install()
class _Conversation(_Plugin, ConversationFuncs, ABC):
"""Conversation类"""
# noinspection SpellCheckingInspection
class Config(BaseModel):
allow_reentry: bool = False
per_chat: bool = True
per_user: bool = True
per_message: bool = False
conversation_timeout: Optional[Union[float, timedelta]] = None
name: Optional[str] = None
map_to_parent: Optional[Dict[object, object]] = None
block: bool = False
def __init_subclass__(cls, **kwargs):
cls._conversation_kwargs = kwargs
super(_Conversation, cls).__init_subclass__()
return cls
@property
def handlers(self) -> List[HandlerType]:
with self._lock:
if self._handlers is None:
self._handlers = []
entry_points: List[HandlerType] = []
states: Dict[Any, List[HandlerType]] = {}
fallbacks: List[HandlerType] = []
for attr in dir(self):
if (
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
and (func := getattr(self, attr, None)) is not None
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
):
conversation_data: "ConversationData"
handlers: List[HandlerType] = []
for data in datas:
handlers.append(
data.type(
callback=func,
**data.kwargs,
)
)
if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None):
if (_type := conversation_data.type) is ConversationDataType.Entry:
entry_points.extend(handlers)
elif _type is ConversationDataType.State:
if conversation_data.state in states:
states[conversation_data.state].extend(handlers)
else:
states[conversation_data.state] = handlers
elif _type is ConversationDataType.Fallback:
fallbacks.extend(handlers)
else:
self._handlers.extend(handlers)
else:
self._handlers.extend(handlers)
if entry_points and states and fallbacks:
kwargs = self._conversation_kwargs
kwargs.update(self.Config().dict())
self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
else:
temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items()))
logger.warning(
"'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason)
)
return self._handlers
class Plugin(_Plugin, ABC):
"""插件"""
Conversation = _Conversation
PluginType = TypeVar("PluginType", bound=_Plugin)
def get_all_plugins() -> Iterable[Type[PluginType]]:
"""获取所有 Plugin 的子类"""
return filter(
lambda x: x.__name__[0] != "_" and not isabstract(x),
chain(Plugin.__subclasses__(), _Conversation.__subclasses__()),
)

View File

@ -1,14 +0,0 @@
from core.base.mysql import MySQL
from core.base.redisdb import RedisDB
from core.service import init_service
from .cache import QuizCache
from .repositories import QuizRepository
from .services import QuizService
@init_service
def create_quiz_service(mysql: MySQL, redis: RedisDB):
_repository = QuizRepository(mysql)
_cache = QuizCache(redis)
_service = QuizService(_repository, _cache)
return _service

View File

@ -1,19 +0,0 @@
from typing import List
from .models import Answer, Question
def CreatQuestionFromSQLData(data: tuple) -> List[Question]:
temp_list = []
for temp_data in data:
(question_id, text) = temp_data
temp_list.append(Question(question_id, text))
return temp_list
def CreatAnswerFromSQLData(data: tuple) -> List[Answer]:
temp_list = []
for temp_data in data:
(answer_id, question_id, is_correct, text) = temp_data
temp_list.append(Answer(answer_id, question_id, is_correct, text))
return temp_list

View File

@ -1,98 +0,0 @@
from typing import List, Optional
from sqlmodel import SQLModel, Field, Column, Integer, ForeignKey
from utils.baseobject import BaseObject
from utils.typedefs import JSONDict
class AnswerDB(SQLModel, table=True):
__tablename__ = "answer"
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: int = Field(primary_key=True)
question_id: Optional[int] = Field(
sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT"))
)
is_correct: Optional[bool] = Field()
text: Optional[str] = Field()
class QuestionDB(SQLModel, table=True):
__tablename__ = "question"
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: int = Field(primary_key=True)
text: Optional[str] = Field()
class Answer(BaseObject):
def __init__(self, answer_id: int = 0, question_id: int = 0, is_correct: bool = True, text: str = ""):
"""Answer类
:param answer_id: 答案ID
:param question_id: 与之对应的问题ID
:param is_correct: 该答案是否正确
:param text: 答案文本
"""
self.answer_id = answer_id
self.question_id = question_id
self.text = text
self.is_correct = is_correct
__slots__ = ("answer_id", "question_id", "text", "is_correct")
def to_database_data(self) -> AnswerDB:
data = AnswerDB()
data.id = self.answer_id
data.question_id = self.question_id
data.text = self.text
data.is_correct = self.is_correct
return data
@classmethod
def de_database_data(cls, data: Optional[AnswerDB]) -> Optional["Answer"]:
if data is None:
return cls()
return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct)
class Question(BaseObject):
def __init__(self, question_id: int = 0, text: str = "", answers: List[Answer] = None):
"""Question类
:param question_id: 问题ID
:param text: 问题文本
:param answers: 答案列表
"""
self.question_id = question_id
self.text = text
self.answers = [] if answers is None else answers
def to_database_data(self) -> QuestionDB:
data = QuestionDB()
data.text = self.text
data.id = self.question_id
return data
@classmethod
def de_database_data(cls, data: Optional[QuestionDB]) -> Optional["Question"]:
if data is None:
return cls()
return cls(question_id=data.id, text=data.text)
def to_dict(self) -> JSONDict:
data = super().to_dict()
if self.answers:
data["answers"] = [e.to_dict() for e in self.answers]
return data
@classmethod
def de_json(cls, data: Optional[JSONDict]) -> Optional["Question"]:
data = cls._parse_data(data)
if not data:
return None
data["answers"] = Answer.de_list(data.get("answers"))
return cls(**data)
__slots__ = ("question_id", "text", "answers")

View File

@ -1,10 +0,0 @@
from core.service import init_service
from .services import SearchServices as _SearchServices
__all__ = []
@init_service
def create_search_service():
_service = _SearchServices()
return _service

View File

@ -1,31 +0,0 @@
from abc import ABC, abstractmethod
from typing import Callable
from utils.log import logger
__all__ = ["Service", "init_service"]
class Service(ABC):
@abstractmethod
def __init__(self, *args, **kwargs):
"""初始化"""
async def start(self):
"""启动 service"""
async def stop(self):
"""关闭 service"""
def init_service(func: Callable):
from core.bot import bot
if bot.is_running:
try:
service = bot.init_inject(func)
logger.success(f'服务 "{service.__class__.__name__}" 初始化成功')
bot.add_service(service)
except Exception as e: # pylint: disable=W0703
logger.exception(f"来自{func.__module__}的服务初始化失败:{e}")
return func

View File

View File

@ -0,0 +1,5 @@
"""CookieService"""
from core.services.cookies.services import CookiesService, PublicCookiesService
__all__ = ("CookiesService", "PublicCookiesService")

View File

@ -1,12 +1,15 @@
from typing import List, Union from typing import List, Union
from core.base.redisdb import RedisDB from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.dependence.redisdb import RedisDB
from core.services.cookies.error import CookiesCachePoolExhausted
from utils.error import RegionNotFoundError from utils.error import RegionNotFoundError
from utils.models.base import RegionEnum
from .error import CookiesCachePoolExhausted __all__ = ("PublicCookiesCache",)
class PublicCookiesCache: class PublicCookiesCache(BaseService.Component):
"""使用优先级(score)进行排序对使用次数最少的Cookies进行审核""" """使用优先级(score)进行排序对使用次数最少的Cookies进行审核"""
def __init__(self, redis: RedisDB): def __init__(self, redis: RedisDB):
@ -19,9 +22,8 @@ class PublicCookiesCache:
def get_public_cookies_queue_name(self, region: RegionEnum): def get_public_cookies_queue_name(self, region: RegionEnum):
if region == RegionEnum.HYPERION: if region == RegionEnum.HYPERION:
return f"{self.score_qname}:yuanshen" return f"{self.score_qname}:yuanshen"
elif region == RegionEnum.HOYOLAB: if region == RegionEnum.HOYOLAB:
return f"{self.score_qname}:genshin" return f"{self.score_qname}:genshin"
else:
raise RegionNotFoundError(region.name) raise RegionNotFoundError(region.name)
async def putback_public_cookies(self, uid: int, region: RegionEnum): async def putback_public_cookies(self, uid: int, region: RegionEnum):

View File

@ -7,11 +7,6 @@ class CookiesCachePoolExhausted(CookieServiceError):
super().__init__("Cookies cache pool is exhausted") super().__init__("Cookies cache pool is exhausted")
class CookiesNotFoundError(CookieServiceError):
def __init__(self, user_id):
super().__init__(f"{user_id} cookies not found")
class TooManyRequestPublicCookies(CookieServiceError): class TooManyRequestPublicCookies(CookieServiceError):
def __init__(self, user_id): def __init__(self, user_id):
super().__init__(f"{user_id} too many request public cookies") super().__init__(f"{user_id} too many request public cookies")

View File

@ -0,0 +1,39 @@
import enum
from typing import Optional, Dict
from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index
from core.basemodel import RegionEnum
__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum")
class CookiesStatusEnum(int, enum.Enum):
STATUS_SUCCESS = 0
INVALID_COOKIES = 1
TOO_MANY_REQUESTS = 2
class Cookies(SQLModel):
__table_args__ = (
Index("index_user_account", "user_id", "account_id", unique=True),
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
)
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
user_id: int = Field(
sa_column=Column(BigInteger()),
)
account_id: int = Field(
default=None,
sa_column=Column(
BigInteger(),
),
)
data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
is_share: Optional[bool] = Field(sa_column=Column(Boolean))
class CookiesDataBase(Cookies, table=True):
__tablename__ = "cookies"

View File

@ -0,0 +1,55 @@
from typing import Optional, List
from sqlmodel import select
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.dependence.mysql import MySQL
from core.services.cookies.models import CookiesDataBase as Cookies
from core.sqlmodel.session import AsyncSession
__all__ = ("CookiesRepository",)
class CookiesRepository(BaseService.Component):
def __init__(self, mysql: MySQL):
self.engine = mysql.engine
async def get(
self,
user_id: int,
account_id: Optional[int] = None,
region: Optional[RegionEnum] = None,
) -> Optional[Cookies]:
async with AsyncSession(self.engine) as session:
statement = select(Cookies).where(Cookies.user_id == user_id)
if account_id is not None:
statement = statement.where(Cookies.account_id == account_id)
if region is not None:
statement = statement.where(Cookies.region == region)
results = await session.exec(statement)
return results.first()
async def add(self, cookies: Cookies) -> None:
async with AsyncSession(self.engine) as session:
session.add(cookies)
await session.commit()
async def update(self, cookies: Cookies) -> Cookies:
async with AsyncSession(self.engine) as session:
session.add(cookies)
await session.commit()
await session.refresh(cookies)
return cookies
async def delete(self, cookies: Cookies) -> None:
async with AsyncSession(self.engine) as session:
await session.delete(cookies)
await session.commit()
async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]:
async with AsyncSession(self.engine) as session:
statement = select(Cookies).where(Cookies.region == region)
results = await session.exec(statement)
cookies = results.all()
return cookies

View File

@ -1,67 +1,73 @@
from typing import List from typing import List, Optional
import genshin import genshin
from genshin import GenshinException, InvalidCookies, TooManyRequests, types, Game from genshin import Game, GenshinException, InvalidCookies, TooManyRequests, types
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.services.cookies.cache import PublicCookiesCache
from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies
from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum
from core.services.cookies.repositories import CookiesRepository
from utils.log import logger from utils.log import logger
from utils.models.base import RegionEnum
from .cache import PublicCookiesCache __all__ = ("CookiesService", "PublicCookiesService")
from .error import TooManyRequestPublicCookies, CookieServiceError
from .models import CookiesStatusEnum
from .repositories import CookiesNotFoundError, CookiesRepository
class CookiesService: class CookiesService(BaseService):
def __init__(self, cookies_repository: CookiesRepository) -> None: def __init__(self, cookies_repository: CookiesRepository) -> None:
self._repository: CookiesRepository = cookies_repository self._repository: CookiesRepository = cookies_repository
async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum): async def update(self, cookies: Cookies):
await self._repository.update_cookies(user_id, cookies, region) await self._repository.update(cookies)
async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum): async def add(self, cookies: Cookies):
await self._repository.add_cookies(user_id, cookies, region) await self._repository.add(cookies)
async def get_cookies(self, user_id: int, region: RegionEnum): async def get(
return await self._repository.get_cookies(user_id, region) self,
user_id: int,
account_id: Optional[int] = None,
region: Optional[RegionEnum] = None,
) -> Optional[Cookies]:
return await self._repository.get(user_id, account_id, region)
async def del_cookies(self, user_id: int, region: RegionEnum): async def delete(self, cookies: Cookies) -> None:
return await self._repository.del_cookies(user_id, region) return await self._repository.delete(cookies)
async def add_or_update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
try:
await self.get_cookies(user_id, region)
await self.update_cookies(user_id, cookies, region)
except CookiesNotFoundError:
await self.add_cookies(user_id, cookies, region)
class PublicCookiesService: class PublicCookiesService(BaseService):
def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache): def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache):
self._cache = public_cookies_cache self._cache = public_cookies_cache
self._repository: CookiesRepository = cookies_repository self._repository: CookiesRepository = cookies_repository
self.count: int = 0 self.count: int = 0
self.user_times_limiter = 3 * 3 self.user_times_limiter = 3 * 3
async def initialize(self) -> None:
logger.info("正在初始化公共Cookies池")
await self.refresh()
logger.success("刷新公共Cookies池成功")
async def refresh(self): async def refresh(self):
"""刷新公共Cookies 定时任务 """刷新公共Cookies 定时任务
:return: :return:
""" """
user_list: List[int] = [] user_list: List[int] = []
cookies_list = await self._repository.get_all_cookies(RegionEnum.HYPERION) # 从数据库获取2 cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2
for cookies in cookies_list: for cookies in cookies_list:
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
user_list.append(cookies.user_id) user_list.append(cookies.user_id)
if len(user_list) > 0: if len(user_list) > 0:
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION) add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION)
logger.info(f"国服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]") logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
user_list.clear() user_list.clear()
cookies_list = await self._repository.get_all_cookies(RegionEnum.HOYOLAB) cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB)
for cookies in cookies_list: for cookies in cookies_list:
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
user_list.append(cookies.user_id) user_list.append(cookies.user_id)
if len(user_list) > 0: if len(user_list) > 0:
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB) add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB)
logger.info(f"国际服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]") logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL): async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL):
"""获取公共Cookies """获取公共Cookies
@ -71,20 +77,19 @@ class PublicCookiesService:
""" """
user_times = await self._cache.incr_by_user_times(user_id) user_times = await self._cache.incr_by_user_times(user_id)
if int(user_times) > self.user_times_limiter: if int(user_times) > self.user_times_limiter:
logger.warning(f"用户 [{user_id}] 使用公共Cookie次数已经到达上限") logger.warning("用户 %s 使用公共Cookie次数已经到达上限", user_id)
raise TooManyRequestPublicCookies(user_id) raise TooManyRequestPublicCookies(user_id)
while True: while True:
public_id, count = await self._cache.get_public_cookies(region) public_id, count = await self._cache.get_public_cookies(region)
try: cookies = await self._repository.get(public_id, region=region)
cookies = await self._repository.get_cookies(public_id, region) if cookies is None:
except CookiesNotFoundError:
await self._cache.delete_public_cookies(public_id, region) await self._cache.delete_public_cookies(public_id, region)
continue continue
if region == RegionEnum.HYPERION: if region == RegionEnum.HYPERION:
client = genshin.Client(cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.CHINESE) client = genshin.Client(cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.CHINESE)
elif region == RegionEnum.HOYOLAB: elif region == RegionEnum.HOYOLAB:
client = genshin.Client( client = genshin.Client(
cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn" cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn"
) )
else: else:
raise CookieServiceError raise CookieServiceError
@ -101,13 +106,13 @@ class PublicCookiesService:
logger.warning("Cookies无效 ") logger.warning("Cookies无效 ")
logger.exception(exc) logger.exception(exc)
cookies.status = CookiesStatusEnum.INVALID_COOKIES cookies.status = CookiesStatusEnum.INVALID_COOKIES
await self._repository.update_cookies_ex(cookies, region) await self._repository.update(cookies)
await self._cache.delete_public_cookies(cookies.user_id, region) await self._cache.delete_public_cookies(cookies.user_id, region)
continue continue
except TooManyRequests: except TooManyRequests:
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id) logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
await self._repository.update_cookies_ex(cookies, region) await self._repository.update(cookies)
await self._cache.delete_public_cookies(cookies.user_id, region) await self._cache.delete_public_cookies(cookies.user_id, region)
continue continue
except GenshinException as exc: except GenshinException as exc:

View File

@ -0,0 +1 @@
"""GameService"""

View File

@ -1,12 +1,16 @@
from typing import List from typing import List
from core.base.redisdb import RedisDB from core.base_service import BaseService
from core.dependence.redisdb import RedisDB
__all__ = ["GameCache", "GameCacheForStrategy", "GameCacheForMaterial"]
class GameCache: class GameCache:
def __init__(self, redis: RedisDB, qname: str, ttl: int = 3600): qname: str
def __init__(self, redis: RedisDB, ttl: int = 3600):
self.client = redis.client self.client = redis.client
self.qname = qname
self.ttl = ttl self.ttl = ttl
async def get_url_list(self, character_name: str): async def get_url_list(self, character_name: str):
@ -19,3 +23,11 @@ class GameCache:
await self.client.lpush(qname, *str_list) await self.client.lpush(qname, *str_list)
await self.client.expire(qname, self.ttl) await self.client.expire(qname, self.ttl)
return await self.client.llen(qname) return await self.client.llen(qname)
class GameCacheForStrategy(BaseService.Component, GameCache):
qname = "game:strategy"
class GameCacheForMaterial(BaseService.Component, GameCache):
qname = "game:material"

View File

@ -1,11 +1,14 @@
from typing import List, Optional from typing import List, Optional
from core.base_service import BaseService
from core.services.game.cache import GameCacheForMaterial, GameCacheForStrategy
from modules.apihelper.client.components.hyperion import Hyperion from modules.apihelper.client.components.hyperion import Hyperion
from .cache import GameCache
__all__ = ("GameMaterialService", "GameStrategyService")
class GameStrategyService: class GameStrategyService(BaseService):
def __init__(self, cache: GameCache, collections: Optional[List[int]] = None): def __init__(self, cache: GameCacheForStrategy, collections: Optional[List[int]] = None):
self._cache = cache self._cache = cache
self._hyperion = Hyperion() self._hyperion = Hyperion()
if collections is None: if collections is None:
@ -49,8 +52,8 @@ class GameStrategyService:
return artwork_info.image_urls[0] return artwork_info.image_urls[0]
class GameMaterialService: class GameMaterialService(BaseService):
def __init__(self, cache: GameCache, collections: Optional[List[int]] = None): def __init__(self, cache: GameCacheForMaterial, collections: Optional[List[int]] = None):
self._cache = cache self._cache = cache
self._hyperion = Hyperion() self._hyperion = Hyperion()
self._collections = [428421, 1164644] if collections is None else collections self._collections = [428421, 1164644] if collections is None else collections
@ -91,9 +94,8 @@ class GameMaterialService:
await self._cache.set_url_list(character_name, image_url_list) await self._cache.set_url_list(character_name, image_url_list)
if len(image_url_list) == 0: if len(image_url_list) == 0:
return "" return ""
elif len(image_url_list) == 1: if len(image_url_list) == 1:
return image_url_list[0] return image_url_list[0]
elif character_name in self._special: if character_name in self._special:
return image_url_list[2] return image_url_list[2]
else:
return image_url_list[1] return image_url_list[1]

View File

@ -0,0 +1,3 @@
from .services import PlayersService
__all__ = ("PlayersService",)

View File

@ -0,0 +1,2 @@
class PlayerNotFoundError(Exception):
pass

View File

@ -0,0 +1,96 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, BaseSettings
from sqlalchemy import TypeDecorator
from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime
from core.basemodel import RegionEnum
try:
import ujson as jsonlib
except ImportError:
import json as jsonlib
__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel")
class Player(SQLModel):
__table_args__ = (
Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
)
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
)
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
account_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
is_chosen: Optional[bool] = Field(sa_column=Column(Boolean))
class PlayersDataBase(Player, table=True):
__tablename__ = "players"
class ExtraPlayerInfo(BaseModel):
class Config(BaseSettings.Config):
json_loads = jsonlib.loads
json_dumps = jsonlib.dumps
waifu_id: Optional[int] = None
class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223
impl = VARCHAR(length=521)
cache_ok = True
def process_bind_param(self, value, dialect):
"""
:param value: ExtraPlayerInfo | obj | None
:param dialect:
:return:
"""
if value is not None:
if isinstance(value, ExtraPlayerInfo):
return value.json()
raise TypeError
return value
def process_result_value(self, value, dialect):
"""
:param value: str | obj | None
:param dialect:
:return:
"""
if value is not None:
return ExtraPlayerInfo.parse_raw(value)
return None
class PlayerInfo(SQLModel):
__table_args__ = (
Index("index_user_account_player", "user_id", "player_id", unique=True),
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
)
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
)
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
nickname: Optional[str] = Field()
signature: Optional[str] = Field()
hand_image: Optional[int] = Field()
name_card: Optional[int] = Field()
extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType))
create_time: Optional[datetime] = Field(
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
)
last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
is_update: Optional[bool] = Field(sa_column=Column(Boolean))
class PlayerInfoSQLModel(PlayerInfo, table=True):
__tablename__ = "players_info"

View File

@ -0,0 +1,109 @@
from typing import List, Optional
from sqlmodel import select, delete
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.dependence.mysql import MySQL
from core.services.players.models import PlayerInfoSQLModel
from core.services.players.models import PlayersDataBase as Player
from core.sqlmodel.session import AsyncSession
__all__ = ("PlayersRepository", "PlayerInfoRepository")
class PlayersRepository(BaseService.Component):
def __init__(self, mysql: MySQL):
self.engine = mysql.engine
async def get(
self,
user_id: int,
player_id: Optional[int] = None,
account_id: Optional[int] = None,
region: Optional[RegionEnum] = None,
is_chosen: Optional[bool] = None,
) -> Optional[Player]:
async with AsyncSession(self.engine) as session:
statement = select(Player).where(Player.user_id == user_id)
if player_id is not None:
statement = statement.where(Player.player_id == player_id)
if account_id is not None:
statement = statement.where(Player.account_id == account_id)
if region is not None:
statement = statement.where(Player.region == region)
if is_chosen is not None:
statement = statement.where(Player.is_chosen == is_chosen)
results = await session.exec(statement)
return results.first()
async def add(self, player: Player) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
async def delete(self, player: Player) -> None:
async with AsyncSession(self.engine) as session:
await session.delete(player)
await session.commit()
async def update(self, player: Player) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
await session.refresh(player)
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
async with AsyncSession(self.engine) as session:
statement = select(Player).where(Player.user_id == user_id)
results = await session.exec(statement)
players = results.all()
return players
class PlayerInfoRepository(BaseService.Component):
def __init__(self, mysql: MySQL):
self.engine = mysql.engine
async def get(
self,
user_id: int,
player_id: int,
) -> Optional[PlayerInfoSQLModel]:
async with AsyncSession(self.engine) as session:
statement = (
select(PlayerInfoSQLModel)
.where(PlayerInfoSQLModel.player_id == player_id)
.where(PlayerInfoSQLModel.user_id == user_id)
)
results = await session.exec(statement)
return results.first()
async def add(self, player: PlayerInfoSQLModel) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
async def delete(self, player: PlayerInfoSQLModel) -> None:
async with AsyncSession(self.engine) as session:
await session.delete(player)
await session.commit()
async def delete_by_id(
self,
user_id: int,
player_id: int,
) -> None:
async with AsyncSession(self.engine) as session:
statement = (
delete(PlayerInfoSQLModel)
.where(PlayerInfoSQLModel.player_id == player_id)
.where(PlayerInfoSQLModel.user_id == user_id)
)
await session.execute(statement)
async def update(self, player: PlayerInfoSQLModel) -> None:
async with AsyncSession(self.engine) as session:
session.add(player)
await session.commit()
await session.refresh(player)

View File

@ -0,0 +1,184 @@
from datetime import datetime, timedelta
from typing import List, Optional
from aiohttp import ClientConnectorError
from enkanetwork import (
EnkaNetworkAPI,
VaildateUIDError,
HTTPException,
EnkaPlayerNotFound,
PlayerInfo as EnkaPlayerInfo,
)
from core.base_service import BaseService
from core.basemodel import RegionEnum
from core.config import config
from core.dependence.redisdb import RedisDB
from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo
from core.services.players.repositories import PlayersRepository, PlayerInfoRepository
from utils.enkanetwork import RedisCache
from utils.log import logger
from utils.patch.aiohttp import AioHttpTimeoutException
__all__ = ("PlayersService", "PlayerInfoService")
class PlayersService(BaseService):
def __init__(self, players_repository: PlayersRepository) -> None:
self._repository = players_repository
async def get(
self,
user_id: int,
player_id: Optional[int] = None,
account_id: Optional[int] = None,
region: Optional[RegionEnum] = None,
is_chosen: Optional[bool] = None,
) -> Optional[Player]:
return await self._repository.get(user_id, player_id, account_id, region, is_chosen)
async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]:
return await self._repository.get(user_id, region=region, is_chosen=True)
async def add(self, player: Player) -> None:
await self._repository.add(player)
async def update(self, player: Player) -> None:
await self._repository.update(player)
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
return await self._repository.get_all_by_user_id(user_id)
async def remove_all_by_user_id(self, user_id: int):
players = await self._repository.get_all_by_user_id(user_id)
for player in players:
await self._repository.delete(player)
async def delete(self, player: Player):
await self._repository.delete(player)
class PlayerInfoService(BaseService):
def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository):
self.cache = redis.client
self._players_info_repository = players_info_repository
self.enka_client = EnkaNetworkAPI(lang="chs", user_agent=config.enka_network_api_agent)
self.enka_client.set_cache(RedisCache(redis.client, key="players_info:enka_network", ttl=60))
self.qname = "players_info"
async def get_form_cache(self, player: Player):
qname = f"{self.qname}:{player.user_id}:{player.player_id}"
data = await self.cache.get(qname)
if data is None:
return None
json_data = str(data, encoding="utf-8")
return PlayerInfo.parse_raw(json_data)
async def set_form_cache(self, player: PlayerInfo):
qname = f"{self.qname}:{player.user_id}:{player.player_id}"
await self.cache.set(qname, player.json(), ex=60)
async def get_player_info_from_enka(self, player_id: int) -> Optional[EnkaPlayerInfo]:
try:
response = await self.enka_client.fetch_user(player_id, info=True)
return response.player
except (VaildateUIDError, EnkaPlayerNotFound, HTTPException) as exc:
logger.warning("EnkaNetwork 请求失败: %s", str(exc))
except AioHttpTimeoutException as exc:
logger.warning("EnkaNetwork 请求超时: %s", str(exc))
except ClientConnectorError as exc:
logger.warning("EnkaNetwork 请求错误: %s", str(exc))
except Exception as exc:
logger.error("EnkaNetwork 请求失败: %s", exc_info=exc)
return None
async def get(self, player: Player) -> Optional[PlayerInfo]:
player_info = await self.get_form_cache(player)
if player_info is not None:
return player_info
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
if player_info is None:
player_info_enka = await self.get_player_info_from_enka(player.player_id)
if player_info_enka is None:
return None
player_info = PlayerInfo(
user_id=player.user_id,
player_id=player.player_id,
nickname=player_info_enka.nickname,
signature=player_info_enka.signature,
name_card=player_info_enka.namecard.id,
hand_image=player_info_enka.avatar.id,
create_time=datetime.now(),
last_save_time=datetime.now(),
is_update=True,
)
await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info))
await self.set_form_cache(player_info)
return player_info
if player_info.is_update:
expiration_time = datetime.now() - timedelta(days=7)
if player_info.last_save_time is None or player_info.last_save_time <= expiration_time:
player_info_enka = await self.get_player_info_from_enka(player.player_id)
if player_info_enka is None:
player_info.last_save_time = datetime.now()
await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info))
await self.set_form_cache(player_info)
return player_info
player_info.nickname = player_info_enka.nickname
player_info.name_card = player_info_enka.namecard.id
player_info.signature = player_info_enka.signature
player_info.hand_image = player_info_enka.avatar.id
player_info.nickname = player_info_enka.nickname
player_info.last_save_time = datetime.now()
await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info))
await self.set_form_cache(player_info)
return player_info
async def update_from_enka(self, player: Player) -> bool:
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
if player_info is not None:
player_info_enka = await self.get_player_info_from_enka(player.player_id)
if player_info_enka is None:
return False
player_info.nickname = player_info_enka.nickname
player_info.name_card = player_info_enka.namecard.id
player_info.signature = player_info_enka.signature
player_info.hand_image = player_info_enka.avatar.id
player_info.nickname = player_info_enka.nickname
player_info.last_save_time = datetime.now()
await self._players_info_repository.update(player_info)
return True
return False
async def add_from_enka(self, player: Player) -> bool:
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
if player_info is None:
player_info_enka = await self.get_player_info_from_enka(player.player_id)
if player_info_enka is None:
return False
player_info = PlayerInfoSQLModel(
user_id=player.user_id,
player_id=player.player_id,
nickname=player_info_enka.nickname,
signature=player_info_enka.signature,
name_card=player_info_enka.namecard.id,
hand_image=player_info_enka.avatar.id,
create_time=datetime.now(),
last_save_time=datetime.now(),
is_update=True,
)
await self._players_info_repository.add(player_info)
return True
return False
async def get_form_sql(self, player: Player):
return await self._players_info_repository.get(player.user_id, player.player_id)
async def delete_form_player(self, player: Player):
await self._players_info_repository.delete_by_id(user_id=player.user_id, player_id=player.player_id)
async def add(self, player_info: PlayerInfo):
await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info))
async def delete(self, player_info: PlayerInfo):
await self._players_info_repository.delete(PlayerInfoSQLModel.from_orm(player_info))

View File

@ -0,0 +1 @@
"""QuizService"""

View File

@ -1,12 +1,13 @@
from typing import List from typing import List
import ujson from core.base_service import BaseService
from core.dependence.redisdb import RedisDB
from core.services.quiz.models import Answer, Question
from core.base.redisdb import RedisDB __all__ = ("QuizCache",)
from .models import Answer, Question
class QuizCache: class QuizCache(BaseService.Component):
def __init__(self, redis: RedisDB): def __init__(self, redis: RedisDB):
self.client = redis.client self.client = redis.client
self.question_qname = "quiz:question" self.question_qname = "quiz:question"
@ -18,7 +19,7 @@ class QuizCache:
data_list = [self.question_qname + f":{question_id}" for question_id in await self.client.lrange(qname, 0, -1)] data_list = [self.question_qname + f":{question_id}" for question_id in await self.client.lrange(qname, 0, -1)]
data = await self.client.mget(data_list) data = await self.client.mget(data_list)
for i in data: for i in data:
temp_list.append(Question.de_json(ujson.loads(i))) temp_list.append(Question.parse_raw(i))
return temp_list return temp_list
async def get_all_question_id_list(self) -> List[str]: async def get_all_question_id_list(self) -> List[str]:
@ -29,19 +30,19 @@ class QuizCache:
qname = f"{self.question_qname}:{question_id}" qname = f"{self.question_qname}:{question_id}"
data = await self.client.get(qname) data = await self.client.get(qname)
json_data = str(data, encoding="utf-8") json_data = str(data, encoding="utf-8")
return Question.de_json(ujson.loads(json_data)) return Question.parse_raw(json_data)
async def get_one_answer(self, answer_id: int) -> Answer: async def get_one_answer(self, answer_id: int) -> Answer:
qname = f"{self.answer_qname}:{answer_id}" qname = f"{self.answer_qname}:{answer_id}"
data = await self.client.get(qname) data = await self.client.get(qname)
json_data = str(data, encoding="utf-8") json_data = str(data, encoding="utf-8")
return Answer.de_json(ujson.loads(json_data)) return Answer.parse_raw(json_data)
async def add_question(self, question_list: List[Question] = None) -> int: async def add_question(self, question_list: List[Question] = None) -> int:
if not question_list: if not question_list:
return 0 return 0
for question in question_list: for question in question_list:
await self.client.set(f"{self.question_qname}:{question.question_id}", ujson.dumps(question.to_dict())) await self.client.set(f"{self.question_qname}:{question.question_id}", question.json())
question_id_list = [question.question_id for question in question_list] question_id_list = [question.question_id for question in question_list]
await self.client.lpush(f"{self.question_qname}:id_list", *question_id_list) await self.client.lpush(f"{self.question_qname}:id_list", *question_id_list)
return await self.client.llen(f"{self.question_qname}:id_list") return await self.client.llen(f"{self.question_qname}:id_list")
@ -62,7 +63,7 @@ class QuizCache:
if not answer_list: if not answer_list:
return 0 return 0
for answer in answer_list: for answer in answer_list:
await self.client.set(f"{self.answer_qname}:{answer.answer_id}", ujson.dumps(answer.to_dict())) await self.client.set(f"{self.answer_qname}:{answer.answer_id}", answer.json())
answer_id_list = [answer.answer_id for answer in answer_list] answer_id_list = [answer.answer_id for answer in answer_list]
await self.client.lpush(f"{self.answer_qname}:id_list", *answer_id_list) await self.client.lpush(f"{self.answer_qname}:id_list", *answer_id_list)
return await self.client.llen(f"{self.answer_qname}:id_list") return await self.client.llen(f"{self.answer_qname}:id_list")

View File

@ -0,0 +1,57 @@
from typing import List, Optional
from pydantic import BaseModel
from sqlmodel import Column, Field, ForeignKey, Integer, SQLModel
__all__ = ("Answer", "AnswerDB", "Question", "QuestionDB")
class AnswerDB(SQLModel, table=True):
__tablename__ = "answer"
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True)
)
question_id: Optional[int] = Field(
sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT"))
)
is_correct: Optional[bool] = Field()
text: Optional[str] = Field()
class QuestionDB(SQLModel, table=True):
__tablename__ = "question"
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True)
)
text: Optional[str] = Field()
class Answer(BaseModel):
answer_id: int = 0
question_id: int = 0
is_correct: bool = True
text: str = ""
def to_database_data(self) -> AnswerDB:
return AnswerDB(id=self.answer_id, question_id=self.question_id, text=self.text, is_correct=self.is_correct)
@classmethod
def de_database_data(cls, data: AnswerDB) -> Optional["Answer"]:
return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct)
class Question(BaseModel):
question_id: int = 0
text: str = ""
answers: List[Answer] = []
def to_database_data(self) -> QuestionDB:
return QuestionDB(text=self.text, id=self.question_id)
@classmethod
def de_database_data(cls, data: QuestionDB) -> Optional["Question"]:
return cls(question_id=data.id, text=data.text)

View File

@ -2,54 +2,55 @@ from typing import List
from sqlmodel import select from sqlmodel import select
from core.base.mysql import MySQL from core.base_service import BaseService
from .models import AnswerDB, QuestionDB from core.dependence.mysql import MySQL
from core.services.quiz.models import AnswerDB, QuestionDB
from core.sqlmodel.session import AsyncSession
__all__ = ("QuizRepository",)
class QuizRepository: class QuizRepository(BaseService.Component):
def __init__(self, mysql: MySQL): def __init__(self, mysql: MySQL):
self.mysql = mysql self.engine = mysql.engine
async def get_question_list(self) -> List[QuestionDB]: async def get_question_list(self) -> List[QuestionDB]:
async with self.mysql.Session() as session: async with AsyncSession(self.engine) as session:
query = select(QuestionDB) query = select(QuestionDB)
results = await session.exec(query) results = await session.exec(query)
questions = results.all() return results.all()
return questions
async def get_answers_from_question_id(self, question_id: int) -> List[AnswerDB]: async def get_answers_from_question_id(self, question_id: int) -> List[AnswerDB]:
async with self.mysql.Session() as session: async with AsyncSession(self.engine) as session:
query = select(AnswerDB).where(AnswerDB.question_id == question_id) query = select(AnswerDB).where(AnswerDB.question_id == question_id)
results = await session.exec(query) results = await session.exec(query)
answers = results.all() return results.all()
return answers
async def add_question(self, question: QuestionDB): async def add_question(self, question: QuestionDB):
async with self.mysql.Session() as session: async with AsyncSession(self.engine) as session:
session.add(question) session.add(question)
await session.commit() await session.commit()
async def get_question_by_text(self, text: str) -> QuestionDB: async def get_question_by_text(self, text: str) -> QuestionDB:
async with self.mysql.Session() as session: async with AsyncSession(self.engine) as session:
query = select(QuestionDB).where(QuestionDB.text == text) query = select(QuestionDB).where(QuestionDB.text == text)
results = await session.exec(query) results = await session.exec(query)
question = results.first() return results.first()
return question[0]
async def add_answer(self, answer: AnswerDB): async def add_answer(self, answer: AnswerDB):
async with self.mysql.Session() as session: async with AsyncSession(self.engine) as session:
session.add(answer) session.add(answer)
await session.commit() await session.commit()
async def delete_question_by_id(self, question_id: int): async def delete_question_by_id(self, question_id: int):
async with self.mysql.Session() as session: async with AsyncSession(self.engine) as session:
statement = select(QuestionDB).where(QuestionDB.id == question_id) statement = select(QuestionDB).where(QuestionDB.id == question_id)
results = await session.exec(statement) results = await session.exec(statement)
question = results.one() question = results.one()
await session.delete(question) await session.delete(question)
async def delete_answer_by_id(self, answer_id: int): async def delete_answer_by_id(self, answer_id: int):
async with self.mysql.Session() as session: async with AsyncSession(self.engine) as session:
statement = select(AnswerDB).where(AnswerDB.id == answer_id) statement = select(AnswerDB).where(AnswerDB.id == answer_id)
results = await session.exec(statement) results = await session.exec(statement)
answer = results.one() answer = results.one()

View File

@ -1,12 +1,15 @@
import asyncio import asyncio
from typing import List from typing import List
from .cache import QuizCache from core.base_service import BaseService
from .models import Answer, Question from core.services.quiz.cache import QuizCache
from .repositories import QuizRepository from core.services.quiz.models import Answer, Question
from core.services.quiz.repositories import QuizRepository
__all__ = ("QuizService",)
class QuizService: class QuizService(BaseService):
def __init__(self, repository: QuizRepository, cache: QuizCache): def __init__(self, repository: QuizRepository, cache: QuizCache):
self._repository = repository self._repository = repository
self._cache = cache self._cache = cache

View File

@ -0,0 +1 @@
"""SearchService"""

View File

@ -1,12 +1,11 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Optional, List from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
__all__ = ["BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList"]
from thefuzz import fuzz from thefuzz import fuzz
__all__ = ("BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList")
class BaseEntry(BaseModel): class BaseEntry(BaseModel):
"""所有可搜索条目的基类。 """所有可搜索条目的基类。

View File

@ -5,19 +5,22 @@ import json
import os import os
import time import time
from pathlib import Path from pathlib import Path
from typing import Tuple, List, Optional, Dict from typing import Dict, List, Optional, Tuple
import aiofiles import aiofiles
from async_lru import alru_cache from async_lru import alru_cache
from core.search.models import WeaponEntry, BaseEntry, WeaponsEntry, StrategyEntry, StrategyEntryList from core.base_service import BaseService
from core.services.search.models import BaseEntry, StrategyEntry, StrategyEntryList, WeaponEntry, WeaponsEntry
from utils.const import PROJECT_ROOT from utils.const import PROJECT_ROOT
__all__ = ("SearchServices",)
ENTRY_DAYA_PATH = PROJECT_ROOT.joinpath("data", "entry") ENTRY_DAYA_PATH = PROJECT_ROOT.joinpath("data", "entry")
ENTRY_DAYA_PATH.mkdir(parents=True, exist_ok=True) ENTRY_DAYA_PATH.mkdir(parents=True, exist_ok=True)
class SearchServices: class SearchServices(BaseService):
def __init__(self): def __init__(self):
self._lock = asyncio.Lock() # 访问和修改操作成员变量必须加锁操作 self._lock = asyncio.Lock() # 访问和修改操作成员变量必须加锁操作
self.weapons: List[WeaponEntry] = [] self.weapons: List[WeaponEntry] = []

View File

@ -0,0 +1 @@
"""SignService"""

View File

@ -2,8 +2,10 @@ import enum
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy import func from sqlalchemy import func, BigInteger
from sqlmodel import SQLModel, Field, Enum, Column, DateTime from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer
__all__ = ("SignStatusEnum", "Sign")
class SignStatusEnum(int, enum.Enum): class SignStatusEnum(int, enum.Enum):
@ -19,10 +21,13 @@ class SignStatusEnum(int, enum.Enum):
class Sign(SQLModel, table=True): class Sign(SQLModel, table=True):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(
id: int = Field(primary_key=True) default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
user_id: int = Field(foreign_key="user.user_id") )
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True))
chat_id: Optional[int] = Field(default=None) chat_id: Optional[int] = Field(default=None)
time_created: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), server_default=func.now())) time_created: Optional[datetime] = Field(
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), onupdate=func.now())) sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
)
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
status: Optional[SignStatusEnum] = Field(sa_column=Column(Enum(SignStatusEnum))) status: Optional[SignStatusEnum] = Field(sa_column=Column(Enum(SignStatusEnum)))

View File

@ -0,0 +1,50 @@
from typing import List, Optional
from sqlmodel import select
from core.base_service import BaseService
from core.dependence.mysql import MySQL
from core.services.sign.models import Sign
from core.sqlmodel.session import AsyncSession
__all__ = ("SignRepository",)
class SignRepository(BaseService.Component):
def __init__(self, mysql: MySQL):
self.engine = mysql.engine
async def add(self, sign: Sign):
async with AsyncSession(self.engine) as session:
session.add(sign)
await session.commit()
async def remove(self, sign: Sign):
async with AsyncSession(self.engine) as session:
await session.delete(sign)
await session.commit()
async def update(self, sign: Sign) -> Sign:
async with AsyncSession(self.engine) as session:
session.add(sign)
await session.commit()
await session.refresh(sign)
return sign
async def get_by_user_id(self, user_id: int) -> Optional[Sign]:
async with AsyncSession(self.engine) as session:
statement = select(Sign).where(Sign.user_id == user_id)
results = await session.exec(statement)
return results.first()
async def get_by_chat_id(self, chat_id: int) -> Optional[List[Sign]]:
async with AsyncSession(self.engine) as session:
statement = select(Sign).where(Sign.chat_id == chat_id)
results = await session.exec(statement)
return results.all()
async def get_all(self) -> List[Sign]:
async with AsyncSession(self.engine) as session:
query = select(Sign)
results = await session.exec(query)
return results.all()

View File

@ -1,8 +1,11 @@
from .models import Sign from core.base_service import BaseService
from .repositories import SignRepository from core.services.sign.models import Sign
from core.services.sign.repositories import SignRepository
__all__ = ["SignServices"]
class SignServices: class SignServices(BaseService):
def __init__(self, sign_repository: SignRepository) -> None: def __init__(self, sign_repository: SignRepository) -> None:
self._repository: SignRepository = sign_repository self._repository: SignRepository = sign_repository

View File

@ -0,0 +1 @@
"""TemplateService"""

View File

@ -3,10 +3,14 @@ import pickle # nosec B403
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional from typing import Any, Optional
from core.base.redisdb import RedisDB from core.base_service import BaseService
from core.dependence.redisdb import RedisDB
__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"]
class TemplatePreviewCache: class TemplatePreviewCache(BaseService.Component):
"""暂存渲染模板的数据用于预览""" """暂存渲染模板的数据用于预览"""
def __init__(self, redis: RedisDB): def __init__(self, redis: RedisDB):
@ -29,7 +33,7 @@ class TemplatePreviewCache:
return f"{self.qname}:{key}" return f"{self.qname}:{key}"
class HtmlToFileIdCache: class HtmlToFileIdCache(BaseService.Component):
"""html to file_id 的缓存""" """html to file_id 的缓存"""
def __init__(self, redis: RedisDB): def __init__(self, redis: RedisDB):

View File

@ -1,10 +1,12 @@
from enum import Enum from enum import Enum
from typing import Optional, Union, List from typing import List, Optional, Union
from telegram import Message, InputMediaPhoto, InputMediaDocument from telegram import InputMediaDocument, InputMediaPhoto, Message
from core.template.cache import HtmlToFileIdCache from core.services.template.cache import HtmlToFileIdCache
from core.template.error import ErrorFileType, FileIdNotFound from core.services.template.error import ErrorFileType, FileIdNotFound
__all__ = ["FileType", "RenderResult", "RenderGroupResult"]
class FileType(Enum): class FileType(Enum):
@ -16,9 +18,8 @@ class FileType(Enum):
"""对应的 Telegram media 类型""" """对应的 Telegram media 类型"""
if file_type == FileType.PHOTO: if file_type == FileType.PHOTO:
return InputMediaPhoto return InputMediaPhoto
elif file_type == FileType.DOCUMENT: if file_type == FileType.DOCUMENT:
return InputMediaDocument return InputMediaDocument
else:
raise ErrorFileType raise ErrorFileType

View File

@ -1,44 +1,31 @@
import time import asyncio
from typing import Optional from typing import Optional
from urllib.parse import ( from urllib.parse import urlencode, urljoin, urlsplit
urlencode,
urljoin,
urlsplit,
)
from uuid import uuid4 from uuid import uuid4
from fastapi import HTTPException from fastapi import FastAPI, HTTPException
from fastapi.responses import ( from fastapi.responses import FileResponse, HTMLResponse
FileResponse,
HTMLResponse,
)
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from jinja2 import ( from jinja2 import Environment, FileSystemLoader, Template
Environment,
FileSystemLoader,
Template,
)
from playwright.async_api import ViewportSize from playwright.async_api import ViewportSize
from core.base.aiobrowser import AioBrowser from core.application import Application
from core.base.webserver import webapp from core.base_service import BaseService
from core.bot import bot from core.config import config as application_config
from core.template.cache import ( from core.dependence.aiobrowser import AioBrowser
HtmlToFileIdCache, from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
TemplatePreviewCache, from core.services.template.error import QuerySelectorNotFound
) from core.services.template.models import FileType, RenderResult
from core.template.error import QuerySelectorNotFound
from core.template.models import (
FileType,
RenderResult,
)
from utils.const import PROJECT_ROOT from utils.const import PROJECT_ROOT
from utils.log import logger from utils.log import logger
__all__ = ("TemplateService", "TemplatePreviewer")
class TemplateService:
class TemplateService(BaseService):
def __init__( def __init__(
self, self,
app: Application,
browser: AioBrowser, browser: AioBrowser,
html_to_file_id_cache: HtmlToFileIdCache, html_to_file_id_cache: HtmlToFileIdCache,
preview_cache: TemplatePreviewCache, preview_cache: TemplatePreviewCache,
@ -51,10 +38,12 @@ class TemplateService:
loader=FileSystemLoader(template_dir), loader=FileSystemLoader(template_dir),
enable_async=True, enable_async=True,
autoescape=True, autoescape=True,
auto_reload=bot.config.debug, auto_reload=application_config.debug,
) )
self.using_preview = application_config.debug and application_config.webserver.enable
self.previewer = TemplatePreviewer(self, preview_cache) if self.using_preview:
self.previewer = TemplatePreviewer(self, preview_cache, app.web_app)
self.html_to_file_id_cache = html_to_file_id_cache self.html_to_file_id_cache = html_to_file_id_cache
@ -66,10 +55,11 @@ class TemplateService:
:param template_name: 模板文件名 :param template_name: 模板文件名
:param template_data: 模板数据 :param template_data: 模板数据
""" """
start_time = time.time() loop = asyncio.get_event_loop()
start_time = loop.time()
template = self.get_template(template_name) template = self.get_template(template_name)
html = await template.render_async(**template_data) html = await template.render_async(**template_data)
logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}") logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
return html return html
async def render( async def render(
@ -100,19 +90,20 @@ class TemplateService:
:param filename: 文件名字 :param filename: 文件名字
:return: :return:
""" """
start_time = time.time() loop = asyncio.get_event_loop()
start_time = loop.time()
template = self.get_template(template_name) template = self.get_template(template_name)
if bot.config.debug: if self.using_preview:
preview_url = await self.previewer.get_preview_url(template_name, template_data) preview_url = await self.previewer.get_preview_url(template_name, template_data)
logger.debug(f"调试模板 URL: {preview_url}") logger.debug("调试模板 URL: \n%s", preview_url)
html = await template.render_async(**template_data) html = await template.render_async(**template_data)
logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}") logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
file_id = await self.html_to_file_id_cache.get_data(html, file_type.name) file_id = await self.html_to_file_id_cache.get_data(html, file_type.name)
if file_id and not bot.config.debug: if file_id and not application_config.debug:
logger.debug(f"{template_name} 命中缓存,返回 file_id {file_id}") logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id)
return RenderResult( return RenderResult(
html=html, html=html,
photo=file_id, photo=file_id,
@ -125,7 +116,7 @@ class TemplateService:
) )
browser = await self._browser.get_browser() browser = await self._browser.get_browser()
start_time = time.time() start_time = loop.time()
page = await browser.new_page(viewport=viewport) page = await browser.new_page(viewport=viewport)
uri = (PROJECT_ROOT / template.filename).as_uri() uri = (PROJECT_ROOT / template.filename).as_uri()
await page.goto(uri) await page.goto(uri)
@ -142,10 +133,10 @@ class TemplateService:
if not clip: if not clip:
raise QuerySelectorNotFound raise QuerySelectorNotFound
except QuerySelectorNotFound: except QuerySelectorNotFound:
logger.warning(f"未找到 {query_selector} 元素") logger.warning("未找到 %s 元素", query_selector)
png_data = await page.screenshot(clip=clip, full_page=full_page) png_data = await page.screenshot(clip=clip, full_page=full_page)
await page.close() await page.close()
logger.debug(f"{template_name} 图片渲染使用了 {str(time.time() - start_time)}") logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time))
return RenderResult( return RenderResult(
html=html, html=html,
photo=png_data, photo=png_data,
@ -158,15 +149,21 @@ class TemplateService:
) )
class TemplatePreviewer: class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug):
def __init__(self, template_service: TemplateService, cache: TemplatePreviewCache): def __init__(
self,
template_service: TemplateService,
cache: TemplatePreviewCache,
web_app: FastAPI,
):
self.web_app = web_app
self.template_service = template_service self.template_service = template_service
self.cache = cache self.cache = cache
self.register_routes() self.register_routes()
async def get_preview_url(self, template: str, data: dict): async def get_preview_url(self, template: str, data: dict):
"""获取预览 URL""" """获取预览 URL"""
components = urlsplit(bot.config.webserver.url) components = urlsplit(application_config.webserver.url)
path = urljoin("/preview/", template) path = urljoin("/preview/", template)
query = {} query = {}
@ -176,12 +173,13 @@ class TemplatePreviewer:
await self.cache.set_data(key, data) await self.cache.set_data(key, data)
query["key"] = key query["key"] = key
# noinspection PyProtectedMember
return components._replace(path=path, query=urlencode(query)).geturl() return components._replace(path=path, query=urlencode(query)).geturl()
def register_routes(self): def register_routes(self):
"""注册预览用到的路由""" """注册预览用到的路由"""
@webapp.get("/preview/{path:path}") @self.web_app.get("/preview/{path:path}")
async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612 async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612
# 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源 # 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源
if not path.endswith(".html"): if not path.endswith(".html"):
@ -206,4 +204,4 @@ class TemplatePreviewer:
for name in ["cache", "resources"]: for name in ["cache", "resources"]:
directory = PROJECT_ROOT / name directory = PROJECT_ROOT / name
directory.mkdir(exist_ok=True) directory.mkdir(exist_ok=True)
webapp.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name) self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)

View File

View File

@ -0,0 +1,24 @@
from typing import List
from core.base_service import BaseService
from core.dependence.redisdb import RedisDB
__all__ = ("UserAdminCache",)
class UserAdminCache(BaseService.Component):
def __init__(self, redis: RedisDB):
self.client = redis.client
self.qname = "users:admin"
async def ismember(self, user_id: int) -> bool:
return self.client.sismember(self.qname, user_id)
async def get_all(self) -> List[int]:
return [int(str_data) for str_data in await self.client.smembers(self.qname)]
async def set(self, user_id: int) -> bool:
return await self.client.sadd(self.qname, user_id)
async def remove(self, user_id: int) -> bool:
return await self.client.srem(self.qname, user_id)

View File

@ -0,0 +1,34 @@
import enum
from datetime import datetime
from typing import Optional
from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer
__all__ = (
"User",
"UserDataBase",
"PermissionsEnum",
)
class PermissionsEnum(int, enum.Enum):
OWNER = 1
ADMIN = 2
PUBLIC = 3
class User(SQLModel):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
)
user_id: int = Field(unique=True, sa_column=Column(BigInteger()))
permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum)))
locale: Optional[str] = Field()
ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
is_banned: Optional[int] = Field()
class UserDataBase(User, table=True):
__tablename__ = "users"

View File

@ -0,0 +1,44 @@
from typing import Optional, List
from sqlmodel import select
from core.base_service import BaseService
from core.dependence.mysql import MySQL
from core.services.users.models import UserDataBase as User
from core.sqlmodel.session import AsyncSession
__all__ = ("UserRepository",)
class UserRepository(BaseService.Component):
def __init__(self, mysql: MySQL):
self.engine = mysql.engine
async def get_by_user_id(self, user_id: int) -> Optional[User]:
async with AsyncSession(self.engine) as session:
statement = select(User).where(User.user_id == user_id)
results = await session.exec(statement)
return results.first()
async def add(self, user: User):
async with AsyncSession(self.engine) as session:
session.add(user)
await session.commit()
async def update(self, user: User) -> User:
async with AsyncSession(self.engine) as session:
session.add(user)
await session.commit()
await session.refresh(user)
return user
async def remove(self, user: User):
async with AsyncSession(self.engine) as session:
await session.delete(user)
await session.commit()
async def get_all(self) -> List[User]:
async with AsyncSession(self.engine) as session:
statement = select(User)
results = await session.exec(statement)
return results.all()

View File

@ -0,0 +1,79 @@
from typing import List, Optional
from core.base_service import BaseService
from core.config import config
from core.services.users.cache import UserAdminCache
from core.services.users.models import PermissionsEnum, UserDataBase as User
from core.services.users.repositories import UserRepository
__all__ = ("UserService", "UserAdminService")
from utils.log import logger
class UserService(BaseService):
def __init__(self, user_repository: UserRepository) -> None:
self._repository: UserRepository = user_repository
async def get_user_by_id(self, user_id: int) -> Optional[User]:
"""从数据库获取用户信息
:param user_id:用户ID
:return: User
"""
return await self._repository.get_by_user_id(user_id)
async def remove(self, user: User):
return await self._repository.remove(user)
async def update_user(self, user: User):
return await self._repository.add(user)
class UserAdminService(BaseService):
def __init__(self, user_repository: UserRepository, cache: UserAdminCache):
self.user_repository = user_repository
self._cache = cache
async def initialize(self):
owner = config.owner
if owner:
user = await self.user_repository.get_by_user_id(owner)
await self._cache.set(user.user_id)
if user:
if user.permissions != PermissionsEnum.OWNER:
user.permissions = PermissionsEnum.OWNER
await self.user_repository.update(user)
else:
user = User(user_id=owner, permissions=PermissionsEnum.OWNER)
await self.user_repository.add(user)
else:
logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限")
async def is_admin(self, user_id: int) -> bool:
return await self._cache.ismember(user_id)
async def get_admin_list(self) -> List[int]:
return await self._cache.get_all()
async def add_admin(self, user_id: int) -> bool:
user = await self.user_repository.get_by_user_id(user_id)
if user:
if user.permissions == PermissionsEnum.OWNER:
return False
if user.permissions != PermissionsEnum.ADMIN:
user.permissions = PermissionsEnum.ADMIN
await self.user_repository.update(user)
else:
user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN)
await self.user_repository.add(user)
return await self._cache.set(user.user_id)
async def delete_admin(self, user_id: int) -> bool:
user = await self.user_repository.get_by_user_id(user_id)
if user:
if user.permissions == PermissionsEnum.OWNER:
return True # 假装移除成功
user.permissions = PermissionsEnum.PUBLIC
await self.user_repository.update(user)
return await self._cache.remove(user.user_id)
return False

View File

@ -0,0 +1 @@
"""WikiService"""

View File

@ -1,10 +1,13 @@
import ujson as json import ujson as json
from core.base.redisdb import RedisDB from core.base_service import BaseService
from core.dependence.redisdb import RedisDB
from modules.wiki.base import Model from modules.wiki.base import Model
__all__ = ["WikiCache"]
class WikiCache:
class WikiCache(BaseService.Component):
def __init__(self, redis: RedisDB): def __init__(self, redis: RedisDB):
self.client = redis.client self.client = redis.client
self.qname = "wiki" self.qname = "wiki"

View File

@ -1,12 +1,15 @@
from typing import List, NoReturn, Optional from typing import List, NoReturn, Optional
from core.wiki.cache import WikiCache from core.base_service import BaseService
from core.services.wiki.cache import WikiCache
from modules.wiki.character import Character from modules.wiki.character import Character
from modules.wiki.weapon import Weapon from modules.wiki.weapon import Weapon
from utils.log import logger from utils.log import logger
__all__ = ["WikiService"]
class WikiService:
class WikiService(BaseService):
def __init__(self, cache: WikiCache): def __init__(self, cache: WikiCache):
self._cache = cache self._cache = cache
"""Redis 在这里的作用是作为持久化""" """Redis 在这里的作用是作为持久化"""
@ -18,7 +21,7 @@ class WikiService:
async def refresh_weapon(self) -> NoReturn: async def refresh_weapon(self) -> NoReturn:
weapon_name_list = await Weapon.get_name_list() weapon_name_list = await Weapon.get_name_list()
logger.info(f"一共找到 {len(weapon_name_list)} 把武器信息") logger.info("一共找到 %s 把武器信息", len(weapon_name_list))
weapon_list = [] weapon_list = []
num = 0 num = 0
@ -26,7 +29,7 @@ class WikiService:
weapon_list.append(weapon) weapon_list.append(weapon)
num += 1 num += 1
if num % 10 == 0: if num % 10 == 0:
logger.info(f"现在已经获取到 {num} 把武器信息") logger.info("现在已经获取到 %s 把武器信息", num)
logger.info("写入武器信息到Redis") logger.info("写入武器信息到Redis")
self._weapon_list = weapon_list self._weapon_list = weapon_list
@ -35,7 +38,7 @@ class WikiService:
async def refresh_characters(self) -> NoReturn: async def refresh_characters(self) -> NoReturn:
character_name_list = await Character.get_name_list() character_name_list = await Character.get_name_list()
logger.info(f"一共找到 {len(character_name_list)} 个角色信息") logger.info("一共找到 %s 个角色信息", len(character_name_list))
character_list = [] character_list = []
num = 0 num = 0
@ -43,7 +46,7 @@ class WikiService:
character_list.append(character) character_list.append(character)
num += 1 num += 1
if num % 10 == 0: if num % 10 == 0:
logger.info(f"现在已经获取到 {num} 个角色信息") logger.info("现在已经获取到 %s 个角色信息", num)
logger.info("写入角色信息到Redis") logger.info("写入角色信息到Redis")
self._character_list = character_list self._character_list = character_list

View File

@ -1,11 +0,0 @@
from core.base.mysql import MySQL
from core.service import init_service
from .repositories import SignRepository
from .services import SignServices
@init_service
def create_game_strategy_service(mysql: MySQL):
_repository = SignRepository(mysql)
_service = SignServices(_repository)
return _service

Some files were not shown because too many files have changed in this diff Show More