mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-11-21 14:48:20 +00:00
♻️ PaiGram V4
Co-authored-by: luoshuijs <luoshuijs@outlook.com> Co-authored-by: Karako <karakohear@gmail.com> Co-authored-by: xtaodada <xtao@xtaolink.cn>
This commit is contained in:
parent
baceace292
commit
233e7ab58d
32
.env.example
32
.env.example
@ -1,6 +1,12 @@
|
|||||||
# debug 开关
|
# debug 开关
|
||||||
DEBUG=false
|
DEBUG=false
|
||||||
|
|
||||||
|
AUTO_RELOAD=false
|
||||||
|
RELOAD_DELAY=0.25
|
||||||
|
RELOAD_DIRS=[]
|
||||||
|
RELOAD_INCLUDE=[]
|
||||||
|
RELOAD_EXCLUDE=[]
|
||||||
|
|
||||||
# MySQL
|
# MySQL
|
||||||
DB_HOST=127.0.0.1
|
DB_HOST=127.0.0.1
|
||||||
DB_PORT=3306
|
DB_PORT=3306
|
||||||
@ -17,14 +23,14 @@ REDIS_PASSWORD=""
|
|||||||
# 联系 https://t.me/BotFather 使用 /newbot 命令创建机器人并获取 token
|
# 联系 https://t.me/BotFather 使用 /newbot 命令创建机器人并获取 token
|
||||||
BOT_TOKEN="xxxxxxx"
|
BOT_TOKEN="xxxxxxx"
|
||||||
|
|
||||||
# bot 管理员
|
# bot 所有者
|
||||||
ADMINS=[{ "username": "", "user_id": -1 }]
|
OWNER=0
|
||||||
|
|
||||||
# 记录错误并发送消息通知开发人员 可选配置项
|
# 记录错误并发送消息通知开发人员 可选配置项
|
||||||
# ERROR_NOTIFICATION_CHAT_ID=chat_id
|
# ERROR_NOTIFICATION_CHAT_ID=chat_id
|
||||||
|
|
||||||
# 文章推送群组 可选配置项
|
# 文章推送群组 可选配置项
|
||||||
# CHANNELS=[{ "name": "", "chat_id": 1}]
|
# CHANNELS=[]
|
||||||
|
|
||||||
# 是否允许机器人邀请到其他群 默认不允许 如果允许 可以允许全部人或有认证选项 可选配置项
|
# 是否允许机器人邀请到其他群 默认不允许 如果允许 可以允许全部人或有认证选项 可选配置项
|
||||||
# JOIN_GROUPS = "NO_ALLOW"
|
# JOIN_GROUPS = "NO_ALLOW"
|
||||||
@ -33,20 +39,20 @@ ADMINS=[{ "username": "", "user_id": -1 }]
|
|||||||
# VERIFY_GROUPS=[]
|
# VERIFY_GROUPS=[]
|
||||||
|
|
||||||
# logger 配置 可选配置项
|
# logger 配置 可选配置项
|
||||||
LOGGER_NAME="TGPaimon"
|
# LOGGER_NAME="TGPaimon"
|
||||||
# 打印时的宽度
|
# 打印时的宽度
|
||||||
LOGGER_WIDTH=180
|
# LOGGER_WIDTH=180
|
||||||
# log 文件存放目录
|
# log 文件存放目录
|
||||||
LOGGER_LOG_PATH="logs"
|
# LOGGER_LOG_PATH="logs"
|
||||||
# log 时间格式,参考 datetime.strftime
|
# log 时间格式,参考 datetime.strftime
|
||||||
LOGGER_TIME_FORMAT="[%Y-%m-%d %X]"
|
# LOGGER_TIME_FORMAT="[%Y-%m-%d %X]"
|
||||||
# log 高亮关键词
|
# log 高亮关键词
|
||||||
LOGGER_RENDER_KEYWORDS=["BOT"]
|
# LOGGER_RENDER_KEYWORDS=["BOT"]
|
||||||
# traceback 相关配置
|
# traceback 相关配置
|
||||||
LOGGER_TRACEBACK_MAX_FRAMES=20
|
# LOGGER_TRACEBACK_MAX_FRAMES=20
|
||||||
LOGGER_LOCALS_MAX_DEPTH=0
|
# LOGGER_LOCALS_MAX_DEPTH=0
|
||||||
LOGGER_LOCALS_MAX_LENGTH=10
|
# LOGGER_LOCALS_MAX_LENGTH=10
|
||||||
LOGGER_LOCALS_MAX_STRING=80
|
# LOGGER_LOCALS_MAX_STRING=80
|
||||||
# 可被 logger 打印的 record 的名称(默认包含了 LOGGER_NAME )
|
# 可被 logger 打印的 record 的名称(默认包含了 LOGGER_NAME )
|
||||||
LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"]
|
LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"]
|
||||||
|
|
||||||
@ -77,7 +83,7 @@ LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"]
|
|||||||
# ENKA_NETWORK_API_AGENT=""
|
# ENKA_NETWORK_API_AGENT=""
|
||||||
|
|
||||||
# Web Server
|
# Web Server
|
||||||
# 目前只用于预览模板,仅开发环境启动
|
# WEB_SWITCH=False # 是否开启
|
||||||
# WEB_URL=http://localhost:8080/
|
# WEB_URL=http://localhost:8080/
|
||||||
# WEB_HOST=localhost
|
# WEB_HOST=localhost
|
||||||
# WEB_PORT=8080
|
# WEB_PORT=8080
|
||||||
|
54
.github/workflows/integration-test.yml
vendored
Normal file
54
.github/workflows/integration-test.yml
vendored
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
name: Integration Test
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
paths:
|
||||||
|
- 'tests/integration/**'
|
||||||
|
pull_request:
|
||||||
|
types: [ opened, synchronize ]
|
||||||
|
paths:
|
||||||
|
- 'core/services/**'
|
||||||
|
- 'core/dependence/**'
|
||||||
|
- 'tests/integration/**'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
pytest:
|
||||||
|
name: pytest
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
services:
|
||||||
|
mysql:
|
||||||
|
image: mysql:5.7
|
||||||
|
env:
|
||||||
|
MYSQL_DATABASE: integration_test
|
||||||
|
MYSQL_ROOT_PASSWORD: 123456test
|
||||||
|
ports:
|
||||||
|
- 3306:3306
|
||||||
|
options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
|
||||||
|
redis:
|
||||||
|
image: redis
|
||||||
|
ports:
|
||||||
|
- 6379:6379
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Set up Python 3.11
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.11
|
||||||
|
- name: Setup integration test environment
|
||||||
|
run: cp tests/integration/.env.example .env && cp tests/integration/.env.example tests/integration/.env
|
||||||
|
- name: Create venv
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
python3 -m venv venv
|
||||||
|
- name: Install requirements
|
||||||
|
run: |
|
||||||
|
source venv/bin/activate
|
||||||
|
python3 -m pip install --upgrade poetry
|
||||||
|
python3 -m poetry install --extras all
|
||||||
|
- name: Run test
|
||||||
|
run: |
|
||||||
|
source venv/bin/activate
|
||||||
|
python3 -m pytest tests/integration
|
21
.github/workflows/test.yml
vendored
21
.github/workflows/test.yml
vendored
@ -1,19 +1,17 @@
|
|||||||
name: test
|
name: Test modules
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
paths:
|
paths:
|
||||||
- 'tests/**'
|
- 'tests/unit/**'
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [ opened, synchronize ]
|
types: [ opened, synchronize ]
|
||||||
paths:
|
paths:
|
||||||
- 'modules/apihelper/**'
|
- 'modules/apihelper/**'
|
||||||
- 'modules/wiki/**'
|
- 'modules/wiki/**'
|
||||||
- 'tests/**'
|
- 'tests/unit/**'
|
||||||
schedule:
|
|
||||||
- cron: '0 4 * * 3'
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pytest:
|
pytest:
|
||||||
@ -22,16 +20,15 @@ jobs:
|
|||||||
continue-on-error: ${{ matrix.experimental }}
|
continue-on-error: ${{ matrix.experimental }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: [ '3.10' ]
|
|
||||||
os: [ ubuntu-latest, windows-latest ]
|
os: [ ubuntu-latest, windows-latest ]
|
||||||
experimental: [ false ]
|
|
||||||
fail-fast: False
|
fail-fast: False
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- name: Checkout code
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
uses: actions/checkout@v3
|
||||||
uses: actions/setup-python@v4
|
- name: Set up Python 3.11
|
||||||
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: 3.11
|
||||||
- name: restore or create a python virtualenv
|
- name: restore or create a python virtualenv
|
||||||
id: cache
|
id: cache
|
||||||
uses: syphar/restore-virtualenv@v1.2
|
uses: syphar/restore-virtualenv@v1.2
|
||||||
@ -45,4 +42,4 @@ jobs:
|
|||||||
poetry install --extras test
|
poetry install --extras test
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
python -m pytest
|
python -m pytest tests/unit
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -58,6 +58,5 @@ plugins/private
|
|||||||
.pytest_cache
|
.pytest_cache
|
||||||
|
|
||||||
### mtp ###
|
### mtp ###
|
||||||
paimon.session
|
paigram.session
|
||||||
PaimonBot.session
|
paigram.session-journal
|
||||||
PaimonBot.session-journal
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
<h1 align="center">PaiGram</h1>
|
<h1 align="center">PaiGram</h1>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">·<img src="https://img.shields.io/badge/python-3.11%2B-blue" alt="">
|
||||||
<img src="https://img.shields.io/badge/python-3.8%2B-blue" alt="">
|
|
||||||
<img src="https://img.shields.io/badge/works%20on-my%20machine-brightgreen" alt="">
|
<img src="https://img.shields.io/badge/works%20on-my%20machine-brightgreen" alt="">
|
||||||
<img src="https://img.shields.io/badge/status-%E5%92%95%E5%92%95%E5%92%95-blue" alt="">
|
<img src="https://img.shields.io/badge/status-%E5%92%95%E5%92%95%E5%92%95-blue" alt="">
|
||||||
<a href="https://black.readthedocs.io/en/stable/index.html"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="code_style" /></a>
|
<a href="https://black.readthedocs.io/en/stable/index.html"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="code_style" /></a>
|
||||||
@ -19,7 +18,7 @@
|
|||||||
|
|
||||||
## 环境需求
|
## 环境需求
|
||||||
|
|
||||||
- Python 3.8+
|
- Python 3.11+
|
||||||
- MySQL
|
- MySQL
|
||||||
- Redis
|
- Redis
|
||||||
|
|
||||||
|
@ -6,19 +6,13 @@ from logging.config import fileConfig
|
|||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sqlalchemy import (
|
from sqlalchemy import engine_from_config, pool
|
||||||
engine_from_config,
|
|
||||||
pool,
|
|
||||||
)
|
|
||||||
from sqlalchemy.engine import Connection
|
from sqlalchemy.engine import Connection
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
from utils.const import (
|
from core.config import config as BotConfig
|
||||||
CORE_DIR,
|
from utils.const import CORE_DIR, PLUGIN_DIR, PROJECT_ROOT
|
||||||
PLUGIN_DIR,
|
|
||||||
PROJECT_ROOT,
|
|
||||||
)
|
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
@ -28,7 +22,7 @@ config = context.config
|
|||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
fileConfig(config.config_file_name)
|
fileConfig(config.config_file_name) # skipcq: PY-A6006
|
||||||
|
|
||||||
|
|
||||||
def scan_models() -> Iterator[str]:
|
def scan_models() -> Iterator[str]:
|
||||||
@ -46,7 +40,7 @@ def import_models():
|
|||||||
try:
|
try:
|
||||||
import_module(pkg) # 导入 models
|
import_module(pkg) # 导入 models
|
||||||
except Exception as e: # pylint: disable=W0703
|
except Exception as e: # pylint: disable=W0703
|
||||||
logger.error(f'在导入文件 "{pkg}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]')
|
logger.error("在导入文件 %s 的过程中遇到了错误: \n[red bold]%s: %s[/]", pkg, type(e).__name__, e, extra={"markup": True})
|
||||||
|
|
||||||
|
|
||||||
# register our models for alembic to auto-generate migrations
|
# register our models for alembic to auto-generate migrations
|
||||||
@ -61,14 +55,13 @@ target_metadata = SQLModel.metadata
|
|||||||
|
|
||||||
# here we allow ourselves to pass interpolation vars to alembic.ini
|
# here we allow ourselves to pass interpolation vars to alembic.ini
|
||||||
# from the application config module
|
# from the application config module
|
||||||
from core.config import config as botConfig
|
|
||||||
|
|
||||||
section = config.config_ini_section
|
section = config.config_ini_section
|
||||||
config.set_section_option(section, "DB_HOST", botConfig.mysql.host)
|
config.set_section_option(section, "DB_HOST", BotConfig.mysql.host)
|
||||||
config.set_section_option(section, "DB_PORT", str(botConfig.mysql.port))
|
config.set_section_option(section, "DB_PORT", str(BotConfig.mysql.port))
|
||||||
config.set_section_option(section, "DB_USERNAME", botConfig.mysql.username)
|
config.set_section_option(section, "DB_USERNAME", BotConfig.mysql.username)
|
||||||
config.set_section_option(section, "DB_PASSWORD", botConfig.mysql.password)
|
config.set_section_option(section, "DB_PASSWORD", BotConfig.mysql.password)
|
||||||
config.set_section_option(section, "DB_DATABASE", botConfig.mysql.database)
|
config.set_section_option(section, "DB_DATABASE", BotConfig.mysql.database)
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline() -> None:
|
def run_migrations_offline() -> None:
|
||||||
|
@ -5,16 +5,19 @@ Revises:
|
|||||||
Create Date: 2022-09-01 16:55:20.372560
|
Create Date: 2022-09-01 16:55:20.372560
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from alembic import op
|
from base64 import b64decode
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import sqlmodel
|
import sqlmodel
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = "9e9a36470cd5"
|
revision = "9e9a36470cd5"
|
||||||
down_revision = None
|
down_revision = None
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode()
|
||||||
|
old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode()
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
@ -22,7 +25,7 @@ def upgrade() -> None:
|
|||||||
op.create_table(
|
op.create_table(
|
||||||
"question",
|
"question",
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
sa.Column("text", sqlmodel.AutoString(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint("id"),
|
||||||
mysql_charset="utf8mb4",
|
mysql_charset="utf8mb4",
|
||||||
mysql_collate="utf8mb4_general_ci",
|
mysql_collate="utf8mb4_general_ci",
|
||||||
@ -35,7 +38,7 @@ def upgrade() -> None:
|
|||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
sa.Column("yuanshen_uid", sa.Integer(), nullable=True),
|
sa.Column("yuanshen_uid", sa.Integer(), nullable=True),
|
||||||
sa.Column("genshin_uid", sa.Integer(), nullable=True),
|
sa.Column("genshin_uid", sa.Integer(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint("id"),
|
||||||
@ -46,7 +49,7 @@ def upgrade() -> None:
|
|||||||
op.create_table(
|
op.create_table(
|
||||||
"admin",
|
"admin",
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
["user_id"],
|
["user_id"],
|
||||||
["user.user_id"],
|
["user.user_id"],
|
||||||
@ -60,7 +63,7 @@ def upgrade() -> None:
|
|||||||
sa.Column("question_id", sa.Integer(), nullable=True),
|
sa.Column("question_id", sa.Integer(), nullable=True),
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("is_correct", sa.Boolean(), nullable=True),
|
sa.Column("is_correct", sa.Boolean(), nullable=True),
|
||||||
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
sa.Column("text", sqlmodel.AutoString(), nullable=True),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
["question_id"],
|
["question_id"],
|
||||||
["question.id"],
|
["question.id"],
|
||||||
@ -72,7 +75,7 @@ def upgrade() -> None:
|
|||||||
mysql_collate="utf8mb4_general_ci",
|
mysql_collate="utf8mb4_general_ci",
|
||||||
)
|
)
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"hoyoverse_cookies",
|
old_cookies_database_name2,
|
||||||
sa.Column("cookies", sa.JSON(), nullable=True),
|
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"status",
|
"status",
|
||||||
@ -85,7 +88,7 @@ def upgrade() -> None:
|
|||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
["user_id"],
|
["user_id"],
|
||||||
["user.user_id"],
|
["user.user_id"],
|
||||||
@ -95,7 +98,7 @@ def upgrade() -> None:
|
|||||||
mysql_collate="utf8mb4_general_ci",
|
mysql_collate="utf8mb4_general_ci",
|
||||||
)
|
)
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"mihoyo_cookies",
|
old_cookies_database_name1,
|
||||||
sa.Column("cookies", sa.JSON(), nullable=True),
|
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"status",
|
"status",
|
||||||
@ -108,7 +111,7 @@ def upgrade() -> None:
|
|||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
["user_id"],
|
["user_id"],
|
||||||
["user.user_id"],
|
["user.user_id"],
|
||||||
@ -119,6 +122,9 @@ def upgrade() -> None:
|
|||||||
)
|
)
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"sign",
|
"sign",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("chat_id", sa.BigInteger(), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"time_created",
|
"time_created",
|
||||||
sa.DateTime(timezone=True),
|
sa.DateTime(timezone=True),
|
||||||
@ -140,14 +146,11 @@ def upgrade() -> None:
|
|||||||
),
|
),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
|
||||||
sa.Column("chat_id", sa.Integer(), nullable=True),
|
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
["user_id"],
|
["user_id"],
|
||||||
["user.user_id"],
|
["user.user_id"],
|
||||||
),
|
),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint("id", "user_id"),
|
||||||
mysql_charset="utf8mb4",
|
mysql_charset="utf8mb4",
|
||||||
mysql_collate="utf8mb4_general_ci",
|
mysql_collate="utf8mb4_general_ci",
|
||||||
)
|
)
|
||||||
@ -157,8 +160,8 @@ def upgrade() -> None:
|
|||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_table("sign")
|
op.drop_table("sign")
|
||||||
op.drop_table("mihoyo_cookies")
|
op.drop_table(old_cookies_database_name1)
|
||||||
op.drop_table("hoyoverse_cookies")
|
op.drop_table(old_cookies_database_name2)
|
||||||
op.drop_table("answer")
|
op.drop_table("answer")
|
||||||
op.drop_table("admin")
|
op.drop_table("admin")
|
||||||
op.drop_table("user")
|
op.drop_table("user")
|
||||||
|
301
alembic/versions/ddcfba3c7d5c_v4.py
Normal file
301
alembic/versions/ddcfba3c7d5c_v4.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
"""v4
|
||||||
|
|
||||||
|
Revision ID: ddcfba3c7d5c
|
||||||
|
Revises: 9e9a36470cd5
|
||||||
|
Create Date: 2023-02-11 17:07:18.170175
|
||||||
|
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from base64 import b64decode
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.exc import NoSuchTableError
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "ddcfba3c7d5c"
|
||||||
|
down_revision = "9e9a36470cd5"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode()
|
||||||
|
old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
connection = op.get_bind()
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
cookies_table = op.create_table(
|
||||||
|
"cookies",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("account_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("data", sa.JSON(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.Enum(
|
||||||
|
"STATUS_SUCCESS",
|
||||||
|
"INVALID_COOKIES",
|
||||||
|
"TOO_MANY_REQUESTS",
|
||||||
|
name="cookiesstatusenum",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"region",
|
||||||
|
sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("is_share", sa.Boolean(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.Index("index_user_account", "user_id", "account_id", unique=True),
|
||||||
|
mysql_charset="utf8mb4",
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
)
|
||||||
|
for old_cookies_database_name in (old_cookies_database_name1, old_cookies_database_name2):
|
||||||
|
try:
|
||||||
|
statement = f"SELECT * FROM {old_cookies_database_name};" # skipcq: BAN-B608
|
||||||
|
old_cookies_table_data = connection.execute(text(statement))
|
||||||
|
except NoSuchTableError:
|
||||||
|
logger.warning("Table '%s' doesn't exist", old_cookies_database_name)
|
||||||
|
continue
|
||||||
|
if old_cookies_table_data is None:
|
||||||
|
logger.warning("Old Cookies Database is None")
|
||||||
|
continue
|
||||||
|
for row in old_cookies_table_data:
|
||||||
|
try:
|
||||||
|
user_id = row["user_id"]
|
||||||
|
status = row["status"]
|
||||||
|
cookies_row = row["cookies"]
|
||||||
|
cookies_data = json.loads(cookies_row)
|
||||||
|
account_id = cookies_data.get("account_id")
|
||||||
|
if account_id is None: # Cleaning Data 清洗数据
|
||||||
|
account_id = cookies_data.get("ltuid")
|
||||||
|
else:
|
||||||
|
account_mid_v2 = cookies_data.get("account_mid_v2")
|
||||||
|
if account_mid_v2 is not None:
|
||||||
|
cookies_data.pop("account_id")
|
||||||
|
cookies_data.setdefault("account_uid_v2", account_id)
|
||||||
|
if old_cookies_database_name == old_cookies_database_name1:
|
||||||
|
region = "HYPERION"
|
||||||
|
else:
|
||||||
|
region = "HOYOLAB"
|
||||||
|
if account_id is None:
|
||||||
|
logger.warning("Can not get user account_id, user_id :%s", user_id)
|
||||||
|
continue
|
||||||
|
insert = cookies_table.insert().values(
|
||||||
|
user_id=int(user_id),
|
||||||
|
account_id=int(account_id),
|
||||||
|
status=status,
|
||||||
|
data=cookies_data,
|
||||||
|
region=region,
|
||||||
|
is_share=True,
|
||||||
|
)
|
||||||
|
with op.get_context().autocommit_block():
|
||||||
|
connection.execute(insert)
|
||||||
|
except Exception as exc: # pylint: disable=W0703
|
||||||
|
logger.error(
|
||||||
|
"Process %s->cookies Exception", old_cookies_database_name, exc_info=exc
|
||||||
|
) # pylint: disable=W0703
|
||||||
|
players_table = op.create_table(
|
||||||
|
"players",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("account_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("player_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"region",
|
||||||
|
sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("is_chosen", sa.Boolean(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
|
||||||
|
mysql_charset="utf8mb4",
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
statement = "SELECT * FROM user;"
|
||||||
|
old_user_table_data = connection.execute(text(statement))
|
||||||
|
except NoSuchTableError:
|
||||||
|
logger.warning("Table 'user' doesn't exist")
|
||||||
|
return # should not happen
|
||||||
|
if old_user_table_data is not None:
|
||||||
|
for row in old_user_table_data:
|
||||||
|
try:
|
||||||
|
user_id = row["user_id"]
|
||||||
|
y_uid = row["yuanshen_uid"]
|
||||||
|
g_uid = row["genshin_uid"]
|
||||||
|
region = row["region"]
|
||||||
|
account_id = None
|
||||||
|
cookies_row = connection.execute(
|
||||||
|
cookies_table.select().where(cookies_table.c.user_id == user_id)
|
||||||
|
).first()
|
||||||
|
if cookies_row is not None:
|
||||||
|
account_id = cookies_row["account_id"]
|
||||||
|
if y_uid:
|
||||||
|
insert = players_table.insert().values(
|
||||||
|
user_id=int(user_id),
|
||||||
|
player_id=int(y_uid),
|
||||||
|
is_chosen=(region == "HYPERION"),
|
||||||
|
region="HYPERION",
|
||||||
|
account_id=account_id,
|
||||||
|
)
|
||||||
|
with op.get_context().autocommit_block():
|
||||||
|
connection.execute(insert)
|
||||||
|
if g_uid:
|
||||||
|
insert = players_table.insert().values(
|
||||||
|
user_id=int(user_id),
|
||||||
|
player_id=int(g_uid),
|
||||||
|
is_chosen=(region == "HOYOLAB"),
|
||||||
|
region="HOYOLAB",
|
||||||
|
account_id=account_id,
|
||||||
|
)
|
||||||
|
with op.get_context().autocommit_block():
|
||||||
|
connection.execute(insert)
|
||||||
|
except Exception as exc: # pylint: disable=W0703
|
||||||
|
logger.error("Process user->player Exception", exc_info=exc)
|
||||||
|
else:
|
||||||
|
logger.warning("Old User Database is None")
|
||||||
|
|
||||||
|
users_table = op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False, primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"permissions",
|
||||||
|
sa.Enum("OWNER", "ADMIN", "PUBLIC", name="permissionsenum"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("locale", sqlmodel.AutoString(), nullable=True),
|
||||||
|
sa.Column("is_banned", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("ban_end_time", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("ban_start_time", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("user_id"),
|
||||||
|
mysql_charset="utf8mb4",
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
statement = "SELECT * FROM admin;"
|
||||||
|
old_user_table_data = connection.execute(text(statement))
|
||||||
|
except NoSuchTableError:
|
||||||
|
logger.warning("Table 'admin' doesn't exist")
|
||||||
|
return # should not happen
|
||||||
|
if old_user_table_data is not None:
|
||||||
|
for row in old_user_table_data:
|
||||||
|
try:
|
||||||
|
user_id = row["user_id"]
|
||||||
|
insert = users_table.insert().values(
|
||||||
|
user_id=int(user_id),
|
||||||
|
permissions="ADMIN",
|
||||||
|
)
|
||||||
|
with op.get_context().autocommit_block():
|
||||||
|
connection.execute(insert)
|
||||||
|
except Exception as exc: # pylint: disable=W0703
|
||||||
|
logger.error("Process admin->users Exception", exc_info=exc)
|
||||||
|
else:
|
||||||
|
logger.warning("Old User Database is None")
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"players_info",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("player_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("nickname", sqlmodel.AutoString(length=128), nullable=True),
|
||||||
|
sa.Column("signature", sqlmodel.AutoString(length=255), nullable=True),
|
||||||
|
sa.Column("hand_image", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("name_card", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("extra_data", sa.VARCHAR(length=512), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"create_time",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("last_save_time", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("is_update", sa.Boolean(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.Index("index_user_player", "user_id", "player_id", unique=True),
|
||||||
|
mysql_charset="utf8mb4",
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
)
|
||||||
|
|
||||||
|
op.drop_table(old_cookies_database_name1)
|
||||||
|
op.drop_table(old_cookies_database_name2)
|
||||||
|
op.drop_table("admin")
|
||||||
|
op.drop_constraint("sign_ibfk_1", "sign", type_="foreignkey")
|
||||||
|
op.drop_index("user_id", table_name="sign")
|
||||||
|
op.drop_table("user")
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"user",
|
||||||
|
sa.Column("region", sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"), nullable=True),
|
||||||
|
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column("yuanshen_uid", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column("genshin_uid", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
mysql_default_charset="utf8mb4",
|
||||||
|
mysql_engine="InnoDB",
|
||||||
|
)
|
||||||
|
op.create_index("user_id", "user", ["user_id"], unique=False)
|
||||||
|
op.create_table(
|
||||||
|
"admin",
|
||||||
|
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="admin_ibfk_1"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
mysql_default_charset="utf8mb4",
|
||||||
|
mysql_engine="InnoDB",
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
old_cookies_database_name1,
|
||||||
|
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="mihoyo_cookies_ibfk_1"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
mysql_default_charset="utf8mb4",
|
||||||
|
mysql_engine="InnoDB",
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
old_cookies_database_name2,
|
||||||
|
sa.Column("cookies", sa.JSON(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="hoyoverse_cookies_ibfk_1"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
mysql_collate="utf8mb4_general_ci",
|
||||||
|
mysql_default_charset="utf8mb4",
|
||||||
|
mysql_engine="InnoDB",
|
||||||
|
)
|
||||||
|
op.create_foreign_key("sign_ibfk_1", "sign", "user", ["user_id"], ["user_id"])
|
||||||
|
op.create_index("user_id", "sign", ["user_id"], unique=False)
|
||||||
|
op.drop_table("users")
|
||||||
|
op.drop_table("players")
|
||||||
|
op.drop_table("cookies")
|
||||||
|
op.drop_table("players_info")
|
||||||
|
# ### end Alembic commands ###
|
@ -1,14 +0,0 @@
|
|||||||
from core.service import init_service
|
|
||||||
from core.base.mysql import MySQL
|
|
||||||
from core.base.redisdb import RedisDB
|
|
||||||
from core.admin.cache import BotAdminCache
|
|
||||||
from core.admin.repositories import BotAdminRepository
|
|
||||||
from core.admin.services import BotAdminService
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_bot_admin_service(mysql: MySQL, redis: RedisDB):
|
|
||||||
_cache = BotAdminCache(redis)
|
|
||||||
_repository = BotAdminRepository(mysql)
|
|
||||||
_service = BotAdminService(_repository, _cache)
|
|
||||||
return _service
|
|
@ -1,38 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from core.base.redisdb import RedisDB
|
|
||||||
|
|
||||||
|
|
||||||
class BotAdminCache:
|
|
||||||
def __init__(self, redis: RedisDB):
|
|
||||||
self.client = redis.client
|
|
||||||
self.qname = "bot:admin"
|
|
||||||
|
|
||||||
async def get_list(self):
|
|
||||||
return [int(str_data) for str_data in await self.client.lrange(self.qname, 0, -1)]
|
|
||||||
|
|
||||||
async def set_list(self, str_list: List[int], ttl: int = -1):
|
|
||||||
await self.client.ltrim(self.qname, 1, 0)
|
|
||||||
await self.client.lpush(self.qname, *str_list)
|
|
||||||
if ttl != -1:
|
|
||||||
await self.client.expire(self.qname, ttl)
|
|
||||||
count = await self.client.llen(self.qname)
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
class GroupAdminCache:
|
|
||||||
def __init__(self, redis: RedisDB):
|
|
||||||
self.client = redis.client
|
|
||||||
self.qname = "group:admin_list"
|
|
||||||
|
|
||||||
async def get_chat_admin(self, chat_id: int):
|
|
||||||
qname = f"{self.qname}:{chat_id}"
|
|
||||||
return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)]
|
|
||||||
|
|
||||||
async def set_chat_admin(self, chat_id: int, admin_list: List[int]):
|
|
||||||
qname = f"{self.qname}:{chat_id}"
|
|
||||||
await self.client.ltrim(qname, 1, 0)
|
|
||||||
await self.client.lpush(qname, *admin_list)
|
|
||||||
await self.client.expire(qname, 60)
|
|
||||||
count = await self.client.llen(qname)
|
|
||||||
return count
|
|
@ -1,8 +0,0 @@
|
|||||||
from sqlmodel import SQLModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class Admin(SQLModel, table=True):
|
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
|
||||||
|
|
||||||
id: int = Field(primary_key=True)
|
|
||||||
user_id: int = Field(foreign_key="user.user_id")
|
|
@ -1,33 +0,0 @@
|
|||||||
from typing import List, cast
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from core.admin.models import Admin
|
|
||||||
from core.base.mysql import MySQL
|
|
||||||
|
|
||||||
|
|
||||||
class BotAdminRepository:
|
|
||||||
def __init__(self, mysql: MySQL):
|
|
||||||
self.mysql = mysql
|
|
||||||
|
|
||||||
async def delete_by_user_id(self, user_id: int):
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
session = cast(AsyncSession, session)
|
|
||||||
statement = select(Admin).where(Admin.user_id == user_id)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
admin = results.one()
|
|
||||||
await session.delete(admin)
|
|
||||||
|
|
||||||
async def add_by_user_id(self, user_id: int):
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
admin = Admin(user_id=user_id)
|
|
||||||
session.add(admin)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def get_all_user_id(self) -> List[int]:
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
query = select(Admin)
|
|
||||||
results = await session.exec(query)
|
|
||||||
admins = results.all()
|
|
||||||
return [admin[0].user_id for admin in admins]
|
|
@ -1,60 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from asyncmy.errors import IntegrityError
|
|
||||||
from telegram import Bot
|
|
||||||
|
|
||||||
from core.admin.cache import BotAdminCache, GroupAdminCache
|
|
||||||
from core.admin.repositories import BotAdminRepository
|
|
||||||
from core.config import config
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
class BotAdminService:
|
|
||||||
def __init__(self, repository: BotAdminRepository, cache: BotAdminCache):
|
|
||||||
self._repository = repository
|
|
||||||
self._cache = cache
|
|
||||||
|
|
||||||
async def get_admin_list(self) -> List[int]:
|
|
||||||
admin_list = await self._cache.get_list()
|
|
||||||
if len(admin_list) == 0:
|
|
||||||
admin_list = await self._repository.get_all_user_id()
|
|
||||||
for config_admin in config.admins:
|
|
||||||
admin_list.append(config_admin.user_id)
|
|
||||||
await self._cache.set_list(admin_list)
|
|
||||||
return admin_list
|
|
||||||
|
|
||||||
async def add_admin(self, user_id: int) -> bool:
|
|
||||||
try:
|
|
||||||
await self._repository.add_by_user_id(user_id)
|
|
||||||
except IntegrityError:
|
|
||||||
logger.warning("用户 %s 已经存在 Admin 数据库", user_id)
|
|
||||||
admin_list = await self._repository.get_all_user_id()
|
|
||||||
for config_admin in config.admins:
|
|
||||||
admin_list.append(config_admin.user_id)
|
|
||||||
await self._cache.set_list(admin_list)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def delete_admin(self, user_id: int) -> bool:
|
|
||||||
try:
|
|
||||||
await self._repository.delete_by_user_id(user_id)
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
admin_list = await self._repository.get_all_user_id()
|
|
||||||
for config_admin in config.admins:
|
|
||||||
admin_list.append(config_admin.user_id)
|
|
||||||
await self._cache.set_list(admin_list)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class GroupAdminService:
|
|
||||||
def __init__(self, cache: GroupAdminCache):
|
|
||||||
self._cache = cache
|
|
||||||
|
|
||||||
async def get_admins(self, bot: Bot, chat_id: int, extra_user: List[int]) -> List[int]:
|
|
||||||
admin_id_list = await self._cache.get_chat_admin(chat_id)
|
|
||||||
if len(admin_id_list) == 0:
|
|
||||||
admin_list = await bot.get_chat_administrators(chat_id)
|
|
||||||
admin_id_list = [admin.user.id for admin in admin_list]
|
|
||||||
await self._cache.set_chat_admin(chat_id, admin_id_list)
|
|
||||||
admin_id_list += extra_user
|
|
||||||
return admin_id_list
|
|
287
core/application.py
Normal file
287
core/application.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
"""BOT"""
|
||||||
|
import asyncio
|
||||||
|
import signal
|
||||||
|
from functools import wraps
|
||||||
|
from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func
|
||||||
|
from ssl import SSLZeroReturnError
|
||||||
|
from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from telegram import Bot, Update
|
||||||
|
from telegram.error import NetworkError, TelegramError, TimedOut
|
||||||
|
from telegram.ext import (
|
||||||
|
Application as TelegramApplication,
|
||||||
|
ApplicationBuilder as TelegramApplicationBuilder,
|
||||||
|
Defaults,
|
||||||
|
JobQueue,
|
||||||
|
)
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
from uvicorn import Server
|
||||||
|
|
||||||
|
from core.config import config as application_config
|
||||||
|
from core.handler.limiterhandler import LimiterHandler
|
||||||
|
from core.manager import Managers
|
||||||
|
from core.override.telegram import HTTPXRequest
|
||||||
|
from utils.const import WRAPPER_ASSIGNMENTS
|
||||||
|
from utils.log import logger
|
||||||
|
from utils.models.signal import Singleton
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from asyncio import Task
|
||||||
|
from types import FrameType
|
||||||
|
|
||||||
|
__all__ = ("Application",)
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
|
class Application(Singleton):
|
||||||
|
"""Application"""
|
||||||
|
|
||||||
|
_web_server_task: Optional["Task"] = None
|
||||||
|
|
||||||
|
_startup_funcs: List[Callable] = []
|
||||||
|
_shutdown_funcs: List[Callable] = []
|
||||||
|
|
||||||
|
def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None:
|
||||||
|
self._running = False
|
||||||
|
self.managers = managers
|
||||||
|
self.telegram = telegram
|
||||||
|
self.web_server = web_server
|
||||||
|
self.managers.set_application(application=self) # 给 managers 设置 application
|
||||||
|
self.managers.build_executor("Application")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls):
|
||||||
|
managers = Managers()
|
||||||
|
telegram = (
|
||||||
|
TelegramApplicationBuilder()
|
||||||
|
.get_updates_read_timeout(application_config.update_read_timeout)
|
||||||
|
.get_updates_write_timeout(application_config.update_write_timeout)
|
||||||
|
.get_updates_connect_timeout(application_config.update_connect_timeout)
|
||||||
|
.get_updates_pool_timeout(application_config.update_pool_timeout)
|
||||||
|
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
|
||||||
|
.token(application_config.bot_token)
|
||||||
|
.request(
|
||||||
|
HTTPXRequest(
|
||||||
|
connection_pool_size=application_config.connection_pool_size,
|
||||||
|
proxy_url=application_config.proxy_url,
|
||||||
|
read_timeout=application_config.read_timeout,
|
||||||
|
write_timeout=application_config.write_timeout,
|
||||||
|
connect_timeout=application_config.connect_timeout,
|
||||||
|
pool_timeout=application_config.pool_timeout,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
web_server = Server(
|
||||||
|
uvicorn.Config(
|
||||||
|
app=FastAPI(debug=application_config.debug),
|
||||||
|
port=application_config.webserver.port,
|
||||||
|
host=application_config.webserver.host,
|
||||||
|
log_config=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return cls(managers, telegram, web_server)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def running(self) -> bool:
|
||||||
|
"""bot 是否正在运行"""
|
||||||
|
with self._lock:
|
||||||
|
return self._running
|
||||||
|
|
||||||
|
@property
|
||||||
|
def web_app(self) -> FastAPI:
|
||||||
|
"""fastapi app"""
|
||||||
|
return self.web_server.config.app
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bot(self) -> Optional[Bot]:
|
||||||
|
return self.telegram.bot
|
||||||
|
|
||||||
|
@property
|
||||||
|
def job_queue(self) -> Optional[JobQueue]:
|
||||||
|
return self.telegram.job_queue
|
||||||
|
|
||||||
|
async def _on_startup(self) -> None:
|
||||||
|
for func in self._startup_funcs:
|
||||||
|
await self.managers.executor(func, block=getattr(func, "block", False))
|
||||||
|
|
||||||
|
async def _on_shutdown(self) -> None:
|
||||||
|
for func in self._shutdown_funcs:
|
||||||
|
await self.managers.executor(func, block=getattr(func, "block", False))
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""BOT 初始化"""
|
||||||
|
self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制
|
||||||
|
await self.managers.start_dependency() # 启动基础服务
|
||||||
|
await self.managers.init_components() # 实例化组件
|
||||||
|
await self.managers.start_services() # 启动其他服务
|
||||||
|
await self.managers.install_plugins() # 安装插件
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""BOT 关闭"""
|
||||||
|
await self.managers.uninstall_plugins() # 卸载插件
|
||||||
|
await self.managers.stop_services() # 终止其他服务
|
||||||
|
await self.managers.stop_dependency() # 终止基础服务
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""启动 BOT"""
|
||||||
|
logger.info("正在启动 BOT 中...")
|
||||||
|
|
||||||
|
def error_callback(exc: TelegramError) -> None:
|
||||||
|
"""错误信息回调"""
|
||||||
|
self.telegram.create_task(self.telegram.process_error(error=exc, update=None))
|
||||||
|
|
||||||
|
await self.telegram.initialize()
|
||||||
|
logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True})
|
||||||
|
|
||||||
|
if application_config.webserver.enable: # 如果使用 web app
|
||||||
|
server_config = self.web_server.config
|
||||||
|
server_config.setup_event_loop()
|
||||||
|
if not server_config.loaded:
|
||||||
|
server_config.load()
|
||||||
|
self.web_server.lifespan = server_config.lifespan_class(server_config)
|
||||||
|
try:
|
||||||
|
await self.web_server.startup()
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno == 10048:
|
||||||
|
logger.error("Web Server 端口被占用:%s", e)
|
||||||
|
logger.error("Web Server 启动失败,正在退出")
|
||||||
|
raise SystemExit from None
|
||||||
|
|
||||||
|
if self.web_server.should_exit:
|
||||||
|
logger.error("Web Server 启动失败,正在退出")
|
||||||
|
raise SystemExit from None
|
||||||
|
logger.success("Web Server 启动成功")
|
||||||
|
|
||||||
|
self._web_server_task = asyncio.create_task(self.web_server.main_loop())
|
||||||
|
|
||||||
|
for _ in range(5): # 连接至 telegram 服务器
|
||||||
|
try:
|
||||||
|
await self.telegram.updater.start_polling(
|
||||||
|
error_callback=error_callback, allowed_updates=Update.ALL_TYPES
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except TimedOut:
|
||||||
|
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
|
||||||
|
continue
|
||||||
|
except NetworkError as e:
|
||||||
|
logger.exception()
|
||||||
|
if isinstance(e, SSLZeroReturnError):
|
||||||
|
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||||||
|
else:
|
||||||
|
logger.error("网络连接出现问题, 请检查您的网络状况.")
|
||||||
|
raise SystemExit from e
|
||||||
|
|
||||||
|
await self.initialize()
|
||||||
|
logger.success("BOT 初始化成功")
|
||||||
|
logger.debug("BOT 开始启动")
|
||||||
|
|
||||||
|
await self._on_startup()
|
||||||
|
await self.telegram.start()
|
||||||
|
self._running = True
|
||||||
|
logger.success("BOT 启动成功")
|
||||||
|
|
||||||
|
def stop_signal_handler(self, signum: int):
|
||||||
|
"""终止信号处理"""
|
||||||
|
signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")}
|
||||||
|
logger.debug("接收到了终止信号 %s 正在退出...", signals[signum])
|
||||||
|
if self._web_server_task:
|
||||||
|
self._web_server_task.cancel()
|
||||||
|
|
||||||
|
async def idle(self) -> None:
|
||||||
|
"""在接收到中止信号之前,堵塞loop"""
|
||||||
|
|
||||||
|
task = None
|
||||||
|
|
||||||
|
def stop_handler(signum: int, _: "FrameType") -> None:
|
||||||
|
self.stop_signal_handler(signum)
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
for s in (SIGINT, SIGTERM, SIGABRT):
|
||||||
|
signal_func(s, stop_handler)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
task = asyncio.create_task(asyncio.sleep(600))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""关闭"""
|
||||||
|
logger.info("BOT 正在关闭")
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
await self._on_shutdown()
|
||||||
|
|
||||||
|
if self.telegram.updater.running:
|
||||||
|
await self.telegram.updater.stop()
|
||||||
|
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
if self.telegram.running:
|
||||||
|
await self.telegram.stop()
|
||||||
|
|
||||||
|
await self.telegram.shutdown()
|
||||||
|
if self.web_server is not None:
|
||||||
|
try:
|
||||||
|
await self.web_server.shutdown()
|
||||||
|
logger.info("Web Server 已经关闭")
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.success("BOT 关闭成功")
|
||||||
|
|
||||||
|
def launch(self) -> None:
|
||||||
|
"""启动"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(self.start())
|
||||||
|
loop.run_until_complete(self.idle())
|
||||||
|
except (SystemExit, KeyboardInterrupt) as exc:
|
||||||
|
logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号
|
||||||
|
except NetworkError as e:
|
||||||
|
if isinstance(e, SSLZeroReturnError):
|
||||||
|
logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
||||||
|
else:
|
||||||
|
logger.critical("网络连接出现问题, 请检查您的网络状况.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e)
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(self.stop())
|
||||||
|
|
||||||
|
if application_config.reload:
|
||||||
|
raise SystemExit from None
|
||||||
|
|
||||||
|
def on_startup(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""注册一个在 BOT 启动时执行的函数"""
|
||||||
|
|
||||||
|
if func not in self._startup_funcs:
|
||||||
|
self._startup_funcs.append(func)
|
||||||
|
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||||
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""注册一个在 BOT 停止时执行的函数"""
|
||||||
|
|
||||||
|
if func not in self._shutdown_funcs:
|
||||||
|
self._shutdown_funcs.append(func)
|
||||||
|
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||||
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
@ -1,31 +0,0 @@
|
|||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from core.config import BotConfig
|
|
||||||
from core.service import Service
|
|
||||||
|
|
||||||
|
|
||||||
class MySQL(Service):
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: BotConfig) -> Self:
|
|
||||||
return cls(**config.mysql.dict())
|
|
||||||
|
|
||||||
def __init__(self, host: str, port: int, username: str, password: str, database: str):
|
|
||||||
self.database = database
|
|
||||||
self.password = password
|
|
||||||
self.user = username
|
|
||||||
self.port = port
|
|
||||||
self.host = host
|
|
||||||
self.url = f"mysql+asyncmy://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
|
||||||
self.engine = create_async_engine(self.url)
|
|
||||||
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
|
|
||||||
|
|
||||||
async def get_session(self):
|
|
||||||
"""获取会话"""
|
|
||||||
async with self.Session() as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
self.Session.close_all()
|
|
@ -1,64 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
from core.config import (
|
|
||||||
BotConfig,
|
|
||||||
config as botConfig,
|
|
||||||
)
|
|
||||||
from core.service import Service
|
|
||||||
|
|
||||||
__all__ = ["webapp", "WebServer"]
|
|
||||||
|
|
||||||
webapp = FastAPI(debug=botConfig.debug)
|
|
||||||
|
|
||||||
|
|
||||||
@webapp.get("/")
|
|
||||||
def index():
|
|
||||||
return {"Hello": "Paimon"}
|
|
||||||
|
|
||||||
|
|
||||||
class WebServer(Service):
|
|
||||||
debug: bool
|
|
||||||
|
|
||||||
host: str
|
|
||||||
port: int
|
|
||||||
|
|
||||||
server: uvicorn.Server
|
|
||||||
|
|
||||||
_server_task: asyncio.Task
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: BotConfig) -> Service:
|
|
||||||
return cls(debug=config.debug, host=config.webserver.host, port=config.webserver.port)
|
|
||||||
|
|
||||||
def __init__(self, debug: bool, host: str, port: int):
|
|
||||||
self.debug = debug
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
|
|
||||||
self.server = uvicorn.Server(
|
|
||||||
uvicorn.Config(app=webapp, port=port, use_colors=False, host=host, log_config=None)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动 service"""
|
|
||||||
|
|
||||||
# 暂时只在开发环境启动 webserver 用于开发调试
|
|
||||||
if not self.debug:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 防止 uvicorn server 拦截 signals
|
|
||||||
self.server.install_signal_handlers = lambda: None
|
|
||||||
self._server_task = asyncio.create_task(self.server.serve())
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""关闭 service"""
|
|
||||||
if not self.debug:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.server.should_exit = True
|
|
||||||
|
|
||||||
# 等待 task 结束
|
|
||||||
await self._server_task
|
|
60
core/base_service.py
Normal file
60
core/base_service.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from abc import ABC
|
||||||
|
from itertools import chain
|
||||||
|
from typing import ClassVar, Iterable, Type, TypeVar
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from utils.helpers import isabstract
|
||||||
|
|
||||||
|
__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services")
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseService:
|
||||||
|
"""服务基类"""
|
||||||
|
|
||||||
|
_is_component: ClassVar[bool] = False
|
||||||
|
_is_dependence: ClassVar[bool] = False
|
||||||
|
|
||||||
|
def __init_subclass__(cls, load: bool = True, **kwargs):
|
||||||
|
cls.is_dependence = cls._is_dependence
|
||||||
|
cls.is_component = cls._is_component
|
||||||
|
cls.load = load
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
await self.initialize()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize resources used by this service"""
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Stop & clear resources used by this service"""
|
||||||
|
|
||||||
|
|
||||||
|
class _Dependence(_BaseService, ABC):
|
||||||
|
_is_dependence: ClassVar[bool] = True
|
||||||
|
|
||||||
|
|
||||||
|
class _Component(_BaseService, ABC):
|
||||||
|
_is_component: ClassVar[bool] = True
|
||||||
|
|
||||||
|
|
||||||
|
class BaseService(_BaseService, ABC):
|
||||||
|
Dependence: Type[_BaseService] = _Dependence
|
||||||
|
Component: Type[_BaseService] = _Component
|
||||||
|
|
||||||
|
|
||||||
|
BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService)
|
||||||
|
DependenceType = TypeVar("DependenceType", bound=_Dependence)
|
||||||
|
ComponentType = TypeVar("ComponentType", bound=_Component)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
def get_all_services() -> Iterable[Type[_BaseService]]:
|
||||||
|
return filter(
|
||||||
|
lambda x: x.__name__[0] != "_" and x.load and not isabstract(x),
|
||||||
|
chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()),
|
||||||
|
)
|
29
core/basemodel.py
Normal file
29
core/basemodel.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ujson as jsonlib
|
||||||
|
except ImportError:
|
||||||
|
import json as jsonlib
|
||||||
|
|
||||||
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
|
__all__ = ("RegionEnum", "Settings")
|
||||||
|
|
||||||
|
|
||||||
|
class RegionEnum(int, enum.Enum):
|
||||||
|
"""账号数据所在服务器"""
|
||||||
|
|
||||||
|
NULL = 0
|
||||||
|
HYPERION = 1 # 米忽悠国服 hyperion
|
||||||
|
HOYOLAB = 2 # 米忽悠国际服 hoyolab
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
cls.update_forward_refs()
|
||||||
|
return super(Settings, cls).__new__(cls) # pylint: disable=E1120
|
||||||
|
|
||||||
|
class Config(BaseSettings.Config):
|
||||||
|
case_sensitive = False
|
||||||
|
json_loads = jsonlib.loads
|
||||||
|
json_dumps = jsonlib.dumps
|
@ -1,69 +0,0 @@
|
|||||||
from telegram import Update, ReplyKeyboardRemove
|
|
||||||
from telegram.error import BadRequest, Forbidden
|
|
||||||
from telegram.ext import CallbackContext, ConversationHandler
|
|
||||||
|
|
||||||
from core.plugin import handler, conversation
|
|
||||||
from utils.bot import get_chat
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
async def clean_message(context: CallbackContext):
|
|
||||||
job = context.job
|
|
||||||
message_id = job.data
|
|
||||||
chat_info = f"chat_id[{job.chat_id}]"
|
|
||||||
try:
|
|
||||||
chat = await get_chat(job.chat_id)
|
|
||||||
full_name = chat.full_name
|
|
||||||
if full_name:
|
|
||||||
chat_info = f"{full_name}[{chat.id}]"
|
|
||||||
else:
|
|
||||||
chat_info = f"{chat.title}[{chat.id}]"
|
|
||||||
except (BadRequest, Forbidden) as exc:
|
|
||||||
logger.warning("获取 chat info 失败 %s", exc.message)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
|
||||||
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
|
||||||
try:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
|
||||||
except BadRequest as exc:
|
|
||||||
if "not found" in exc.message:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 消息不存在", chat_info, message_id)
|
|
||||||
elif "Message can't be deleted" in exc.message:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 消息无法删除 可能是没有授权", chat_info, message_id)
|
|
||||||
else:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
|
||||||
except Forbidden as exc:
|
|
||||||
if "bot was kicked" in exc.message:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 已经被踢出群", chat_info, message_id)
|
|
||||||
else:
|
|
||||||
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
|
||||||
|
|
||||||
|
|
||||||
def add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int):
|
|
||||||
context.job_queue.run_once(
|
|
||||||
callback=clean_message,
|
|
||||||
when=delete_seconds,
|
|
||||||
data=message_id,
|
|
||||||
name=f"{chat_id}|{message_id}|clean_message",
|
|
||||||
chat_id=chat_id,
|
|
||||||
job_kwargs={"replace_existing": True, "id": f"{chat_id}|{message_id}|clean_message"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _BasePlugin:
|
|
||||||
@staticmethod
|
|
||||||
def _add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int = 60):
|
|
||||||
return add_delete_message_job(context, chat_id, message_id, delete_seconds)
|
|
||||||
|
|
||||||
|
|
||||||
class _Conversation(_BasePlugin):
|
|
||||||
@conversation.fallback
|
|
||||||
@handler.command(command="cancel", block=True)
|
|
||||||
async def cancel(self, update: Update, _: CallbackContext) -> int:
|
|
||||||
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
|
|
||||||
return ConversationHandler.END
|
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin(_BasePlugin):
|
|
||||||
Conversation = _Conversation
|
|
345
core/bot.py
345
core/bot.py
@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
import os
|
|
||||||
from asyncio import CancelledError
|
|
||||||
from importlib import import_module
|
|
||||||
from multiprocessing import RLock as Lock
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable, ClassVar, Dict, Iterator, List, NoReturn, Optional, TYPE_CHECKING, Type, TypeVar
|
|
||||||
|
|
||||||
import genshin
|
|
||||||
import pytz
|
|
||||||
from async_timeout import timeout
|
|
||||||
from telegram import Update
|
|
||||||
from telegram import __version__ as tg_version
|
|
||||||
from telegram.error import NetworkError, TimedOut
|
|
||||||
from telegram.ext import (
|
|
||||||
AIORateLimiter,
|
|
||||||
Application as TgApplication,
|
|
||||||
CallbackContext,
|
|
||||||
Defaults,
|
|
||||||
JobQueue,
|
|
||||||
MessageHandler,
|
|
||||||
filters,
|
|
||||||
TypeHandler,
|
|
||||||
)
|
|
||||||
from telegram.ext.filters import StatusUpdate
|
|
||||||
|
|
||||||
from core.config import BotConfig, config # pylint: disable=W0611
|
|
||||||
from core.error import ServiceNotFoundError
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from core.plugin import Plugin, _Plugin
|
|
||||||
from core.service import Service
|
|
||||||
from metadata.scripts.metadatas import make_github_fast
|
|
||||||
from utils.const import PLUGIN_DIR, PROJECT_ROOT
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["bot"]
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
PluginType = TypeVar("PluginType", bound=_Plugin)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from telegram import __version_info__ as tg_version_info
|
|
||||||
except ImportError:
|
|
||||||
tg_version_info = (0, 0, 0, 0, 0) # type: ignore[assignment]
|
|
||||||
|
|
||||||
if tg_version_info < (20, 0, 0, "alpha", 6):
|
|
||||||
logger.warning(
|
|
||||||
"Bot与当前PTB版本 [cyan bold]%s[/] [red bold]不兼容[/],请更新到最新版本后使用 [blue bold]poetry install[/] 重新安装依赖",
|
|
||||||
tg_version,
|
|
||||||
extra={"markup": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Bot:
|
|
||||||
_lock: ClassVar[Lock] = Lock()
|
|
||||||
_instance: ClassVar[Optional["Bot"]] = None
|
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs) -> "Bot":
|
|
||||||
"""实现单例"""
|
|
||||||
with cls._lock: # 使线程、进程安全
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = object.__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
app: Optional[TgApplication] = None
|
|
||||||
_config: BotConfig = config
|
|
||||||
_services: Dict[Type[T], T] = {}
|
|
||||||
_running: bool = False
|
|
||||||
|
|
||||||
def _inject(self, signature: inspect.Signature, target: Callable[..., T]) -> T:
|
|
||||||
kwargs = {}
|
|
||||||
for name, parameter in signature.parameters.items():
|
|
||||||
if name != "self" and parameter.annotation != inspect.Parameter.empty:
|
|
||||||
if value := self._services.get(parameter.annotation):
|
|
||||||
kwargs[name] = value
|
|
||||||
return target(**kwargs)
|
|
||||||
|
|
||||||
def init_inject(self, target: Callable[..., T]) -> T:
|
|
||||||
"""用于实例化Plugin的方法。用于给插件传入一些必要组件,如 MySQL、Redis等"""
|
|
||||||
if isinstance(target, type):
|
|
||||||
signature = inspect.signature(target.__init__)
|
|
||||||
else:
|
|
||||||
signature = inspect.signature(target)
|
|
||||||
return self._inject(signature, target)
|
|
||||||
|
|
||||||
async def async_inject(self, target: Callable[..., T]) -> T:
|
|
||||||
return await self._inject(inspect.signature(target), target)
|
|
||||||
|
|
||||||
def _gen_pkg(self, root: Path) -> Iterator[str]:
|
|
||||||
"""生成可以用于 import_module 导入的字符串"""
|
|
||||||
for path in root.iterdir():
|
|
||||||
if not path.name.startswith("_"):
|
|
||||||
if path.is_dir():
|
|
||||||
yield from self._gen_pkg(path)
|
|
||||||
elif path.suffix == ".py":
|
|
||||||
yield str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".")
|
|
||||||
|
|
||||||
async def install_plugins(self):
|
|
||||||
"""安装插件"""
|
|
||||||
for pkg in self._gen_pkg(PLUGIN_DIR):
|
|
||||||
try:
|
|
||||||
import_module(pkg) # 导入插件
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception(
|
|
||||||
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
|
||||||
)
|
|
||||||
continue # 如有错误则继续
|
|
||||||
callback_dict: Dict[int, List[Callable]] = {}
|
|
||||||
for plugin_cls in {*Plugin.__subclasses__(), *Plugin.Conversation.__subclasses__()}:
|
|
||||||
path = f"{plugin_cls.__module__}.{plugin_cls.__name__}"
|
|
||||||
try:
|
|
||||||
plugin: PluginType = self.init_inject(plugin_cls)
|
|
||||||
if hasattr(plugin, "__async_init__"):
|
|
||||||
await self.async_inject(plugin.__async_init__)
|
|
||||||
handlers = plugin.handlers
|
|
||||||
for index, handler in enumerate(handlers):
|
|
||||||
if isinstance(handler, TypeHandler): # 对 TypeHandler 进行特殊处理,优先级必须设置 -1,否则无用
|
|
||||||
handlers.pop(index)
|
|
||||||
self.app.add_handler(handler, group=-1)
|
|
||||||
self.app.add_handlers(handlers)
|
|
||||||
if handlers:
|
|
||||||
logger.debug('插件 "%s" 添加了 %s 个 handler ', path, len(handlers))
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
for priority, callback in plugin._new_chat_members_handler_funcs(): # pylint: disable=W0212
|
|
||||||
if not callback_dict.get(priority):
|
|
||||||
callback_dict[priority] = []
|
|
||||||
callback_dict[priority].append(callback)
|
|
||||||
|
|
||||||
error_handlers = plugin.error_handlers
|
|
||||||
for callback, block in error_handlers.items():
|
|
||||||
self.app.add_error_handler(callback, block)
|
|
||||||
if error_handlers:
|
|
||||||
logger.debug('插件 "%s" 添加了 %s 个 error handler ', path, len(error_handlers))
|
|
||||||
|
|
||||||
if jobs := plugin.jobs:
|
|
||||||
logger.debug('插件 "%s" 添加了 %s 个 jobs ', path, len(jobs))
|
|
||||||
logger.success('插件 "%s" 载入成功', path)
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception(
|
|
||||||
'在安装插件 "%s" 的过程中遇到了错误 [red bold]%s[/]', path, type(e).__name__, exc_info=e, extra={"markup": True}
|
|
||||||
)
|
|
||||||
if callback_dict:
|
|
||||||
num = sum(len(callback_dict[i]) for i in callback_dict)
|
|
||||||
|
|
||||||
async def _new_chat_member_callback(update: "Update", context: "CallbackContext"):
|
|
||||||
nonlocal callback
|
|
||||||
for _, value in callback_dict.items():
|
|
||||||
for callback in value:
|
|
||||||
await callback(update, context)
|
|
||||||
|
|
||||||
self.app.add_handler(
|
|
||||||
MessageHandler(callback=_new_chat_member_callback, filters=StatusUpdate.NEW_CHAT_MEMBERS, block=False)
|
|
||||||
)
|
|
||||||
logger.success(
|
|
||||||
"成功添加了 %s 个针对 [blue]%s[/] 的 [blue]MessageHandler[/]",
|
|
||||||
num,
|
|
||||||
StatusUpdate.NEW_CHAT_MEMBERS,
|
|
||||||
extra={"markup": True},
|
|
||||||
)
|
|
||||||
# special handler
|
|
||||||
from plugins.system.start import StartPlugin
|
|
||||||
|
|
||||||
self.app.add_handler(
|
|
||||||
MessageHandler(
|
|
||||||
callback=StartPlugin.unknown_command, filters=filters.COMMAND & filters.ChatType.PRIVATE, block=False
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _start_base_services(self):
|
|
||||||
for pkg in self._gen_pkg(PROJECT_ROOT / "core/base"):
|
|
||||||
try:
|
|
||||||
import_module(pkg)
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception(
|
|
||||||
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
|
||||||
)
|
|
||||||
raise SystemExit from e
|
|
||||||
for base_service_cls in Service.__subclasses__():
|
|
||||||
try:
|
|
||||||
if hasattr(base_service_cls, "from_config"):
|
|
||||||
instance = base_service_cls.from_config(self._config)
|
|
||||||
else:
|
|
||||||
instance = self.init_inject(base_service_cls)
|
|
||||||
await instance.start()
|
|
||||||
logger.success('服务 "%s" 初始化成功', base_service_cls.__name__)
|
|
||||||
self._services.update({base_service_cls: instance})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error('服务 "%s" 初始化失败', base_service_cls.__name__)
|
|
||||||
raise SystemExit from e
|
|
||||||
|
|
||||||
async def start_services(self):
|
|
||||||
"""启动服务"""
|
|
||||||
await self._start_base_services()
|
|
||||||
for path in (PROJECT_ROOT / "core").iterdir():
|
|
||||||
if not path.name.startswith("_") and path.is_dir() and path.name != "base":
|
|
||||||
pkg = str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".")
|
|
||||||
try:
|
|
||||||
import_module(pkg)
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception(
|
|
||||||
'在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]',
|
|
||||||
pkg,
|
|
||||||
type(e).__name__,
|
|
||||||
exc_info=e,
|
|
||||||
extra={"markup": True},
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
async def stop_services(self):
|
|
||||||
"""关闭服务"""
|
|
||||||
if not self._services:
|
|
||||||
return
|
|
||||||
logger.info("正在关闭服务")
|
|
||||||
for _, service in filter(lambda x: not isinstance(x[1], TgApplication), self._services.items()):
|
|
||||||
async with timeout(5):
|
|
||||||
try:
|
|
||||||
if hasattr(service, "stop"):
|
|
||||||
if inspect.iscoroutinefunction(service.stop):
|
|
||||||
await service.stop()
|
|
||||||
else:
|
|
||||||
service.stop()
|
|
||||||
logger.success('服务 "%s" 关闭成功', service.__class__.__name__)
|
|
||||||
except CancelledError:
|
|
||||||
logger.warning('服务 "%s" 关闭超时', service.__class__.__name__)
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception('服务 "%s" 关闭失败', service.__class__.__name__, exc_info=e)
|
|
||||||
|
|
||||||
async def _post_init(self, context: CallbackContext) -> NoReturn:
|
|
||||||
logger.info("开始初始化 genshin.py 相关资源")
|
|
||||||
try:
|
|
||||||
# 替换为 fastgit 镜像源
|
|
||||||
for i in dir(genshin.utility.extdb):
|
|
||||||
if "_URL" in i:
|
|
||||||
setattr(
|
|
||||||
genshin.utility.extdb,
|
|
||||||
i,
|
|
||||||
make_github_fast(getattr(genshin.utility.extdb, i)),
|
|
||||||
)
|
|
||||||
await genshin.utility.update_characters_enka()
|
|
||||||
except Exception as exc: # pylint: disable=W0703
|
|
||||||
logger.error("初始化 genshin.py 相关资源失败")
|
|
||||||
logger.exception(exc)
|
|
||||||
else:
|
|
||||||
logger.success("初始化 genshin.py 相关资源成功")
|
|
||||||
self._services.update({CallbackContext: context})
|
|
||||||
logger.info("开始初始化服务")
|
|
||||||
await self.start_services()
|
|
||||||
logger.info("开始安装插件")
|
|
||||||
await self.install_plugins()
|
|
||||||
logger.info("BOT 初始化成功")
|
|
||||||
|
|
||||||
def launch(self) -> NoReturn:
|
|
||||||
"""启动机器人"""
|
|
||||||
self._running = True
|
|
||||||
logger.info("正在初始化BOT")
|
|
||||||
self.app = (
|
|
||||||
TgApplication.builder()
|
|
||||||
.read_timeout(self.config.read_timeout)
|
|
||||||
.write_timeout(self.config.write_timeout)
|
|
||||||
.connect_timeout(self.config.connect_timeout)
|
|
||||||
.pool_timeout(self.config.pool_timeout)
|
|
||||||
.get_updates_read_timeout(self.config.update_read_timeout)
|
|
||||||
.get_updates_write_timeout(self.config.update_write_timeout)
|
|
||||||
.get_updates_connect_timeout(self.config.update_connect_timeout)
|
|
||||||
.get_updates_pool_timeout(self.config.update_pool_timeout)
|
|
||||||
.rate_limiter(AIORateLimiter())
|
|
||||||
.defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai")))
|
|
||||||
.token(self._config.bot_token)
|
|
||||||
.post_init(self._post_init)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
for _ in range(5):
|
|
||||||
try:
|
|
||||||
self.app.run_polling(
|
|
||||||
close_loop=False,
|
|
||||||
timeout=self.config.timeout,
|
|
||||||
allowed_updates=Update.ALL_TYPES,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except TimedOut:
|
|
||||||
logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True})
|
|
||||||
continue
|
|
||||||
except NetworkError as e:
|
|
||||||
if "SSLZeroReturnError" in str(e):
|
|
||||||
logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.")
|
|
||||||
else:
|
|
||||||
logger.error("网络连接出现问题, 请检查您的网络状况.")
|
|
||||||
break
|
|
||||||
except (SystemExit, KeyboardInterrupt):
|
|
||||||
pass
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception("BOT 执行过程中出现错误", exc_info=e)
|
|
||||||
finally:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
loop.run_until_complete(self.stop_services())
|
|
||||||
loop.close()
|
|
||||||
logger.info("BOT 已经关闭")
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
def find_service(self, target: Type[T]) -> T:
|
|
||||||
"""查找服务。若没找到则抛出 ServiceNotFoundError"""
|
|
||||||
if (result := self._services.get(target)) is None:
|
|
||||||
raise ServiceNotFoundError(target)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def add_service(self, service: T) -> NoReturn:
|
|
||||||
"""添加服务。若已经有同类型的服务,则会抛出异常"""
|
|
||||||
if type(service) in self._services:
|
|
||||||
raise ValueError(f'Service "{type(service)}" is already existed.')
|
|
||||||
self.update_service(service)
|
|
||||||
|
|
||||||
def update_service(self, service: T):
|
|
||||||
"""更新服务。若服务不存在,则添加;若存在,则更新"""
|
|
||||||
self._services.update({type(service): service})
|
|
||||||
|
|
||||||
def contain_service(self, service: Any) -> bool:
|
|
||||||
"""判断服务是否存在"""
|
|
||||||
if isinstance(service, type):
|
|
||||||
return service in self._services
|
|
||||||
else:
|
|
||||||
return service in self._services.values()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def job_queue(self) -> JobQueue:
|
|
||||||
return self.app.job_queue
|
|
||||||
|
|
||||||
@property
|
|
||||||
def services(self) -> Dict[Type[T], T]:
|
|
||||||
return self._services
|
|
||||||
|
|
||||||
@property
|
|
||||||
def config(self) -> BotConfig:
|
|
||||||
return self._config
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_running(self) -> bool:
|
|
||||||
return self._running
|
|
||||||
|
|
||||||
|
|
||||||
bot = Bot()
|
|
1
core/builtins/__init__.py
Normal file
1
core/builtins/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""bot builtins"""
|
38
core/builtins/contexts.py
Normal file
38
core/builtins/contexts.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""上下文管理"""
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from telegram.ext import CallbackContext
|
||||||
|
from telegram import Update
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CallbackContextCV",
|
||||||
|
"UpdateCV",
|
||||||
|
"handler_contexts",
|
||||||
|
"job_contexts",
|
||||||
|
]
|
||||||
|
|
||||||
|
CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback")
|
||||||
|
UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate")
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def handler_contexts(update: "Update", context: "CallbackContext") -> None:
|
||||||
|
context_token = CallbackContextCV.set(context)
|
||||||
|
update_token = UpdateCV.set(update)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
CallbackContextCV.reset(context_token)
|
||||||
|
UpdateCV.reset(update_token)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def job_contexts(context: "CallbackContext") -> None:
|
||||||
|
token = CallbackContextCV.set(context)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
CallbackContextCV.reset(token)
|
309
core/builtins/dispatcher.py
Normal file
309
core/builtins/dispatcher.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
"""参数分发器"""
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from asyncio import AbstractEventLoop
|
||||||
|
from functools import cached_property, lru_cache, partial, wraps
|
||||||
|
from inspect import Parameter, Signature
|
||||||
|
from itertools import chain
|
||||||
|
from types import GenericAlias, MethodType
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from arkowrapper import ArkoWrapper
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from telegram import Bot as TelegramBot, Chat, Message, Update, User
|
||||||
|
from telegram.ext import Application as TelegramApplication, CallbackContext, Job
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
from uvicorn import Server
|
||||||
|
|
||||||
|
from core.application import Application
|
||||||
|
from utils.const import WRAPPER_ASSIGNMENTS
|
||||||
|
from utils.typedefs import R, T
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"catch",
|
||||||
|
"AbstractDispatcher",
|
||||||
|
"BaseDispatcher",
|
||||||
|
"HandlerDispatcher",
|
||||||
|
"JobDispatcher",
|
||||||
|
"dispatched",
|
||||||
|
)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
TargetType = Union[Type, str, Callable[[Any], bool]]
|
||||||
|
|
||||||
|
_CATCH_TARGET_ATTR = "_catch_targets"
|
||||||
|
|
||||||
|
|
||||||
|
def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
|
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
setattr(func, _CATCH_TARGET_ATTR, targets)
|
||||||
|
|
||||||
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorate
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(64)
|
||||||
|
def get_signature(func: Union[type, Callable]) -> Signature:
|
||||||
|
if isinstance(func, type):
|
||||||
|
return inspect.signature(func.__init__)
|
||||||
|
return inspect.signature(func)
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractDispatcher(ABC):
|
||||||
|
"""参数分发器"""
|
||||||
|
|
||||||
|
IGNORED_ATTRS = []
|
||||||
|
|
||||||
|
_args: List[Any] = []
|
||||||
|
_kwargs: Dict[Union[str, Type], Any] = {}
|
||||||
|
_application: "Optional[Application]" = None
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def application(self) -> "Application":
|
||||||
|
if self._application is None:
|
||||||
|
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||||
|
return self._application
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
self._args = list(args)
|
||||||
|
self._kwargs = dict(kwargs)
|
||||||
|
|
||||||
|
for _, value in kwargs.items():
|
||||||
|
type_arg = type(value)
|
||||||
|
if type_arg != str:
|
||||||
|
self._kwargs[type_arg] = value
|
||||||
|
|
||||||
|
for arg in args:
|
||||||
|
type_arg = type(arg)
|
||||||
|
if type_arg != str:
|
||||||
|
self._kwargs[type_arg] = arg
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def catch_funcs(self) -> List[MethodType]:
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
return list(
|
||||||
|
ArkoWrapper(dir(self))
|
||||||
|
.filter(lambda x: not x.startswith("_"))
|
||||||
|
.filter(
|
||||||
|
lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"]
|
||||||
|
)
|
||||||
|
.map(lambda x: getattr(self, x))
|
||||||
|
.filter(lambda x: isinstance(x, MethodType))
|
||||||
|
.filter(lambda x: hasattr(x, "_catch_targets"))
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]:
|
||||||
|
result = {}
|
||||||
|
for catch_func in self.catch_funcs:
|
||||||
|
catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR)
|
||||||
|
for catch_target in catch_targets:
|
||||||
|
result[catch_target] = catch_func
|
||||||
|
return result
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def dispatch_funcs(self) -> List[MethodType]:
|
||||||
|
return list(
|
||||||
|
ArkoWrapper(dir(self))
|
||||||
|
.filter(lambda x: x.startswith("dispatch_by_"))
|
||||||
|
.map(lambda x: getattr(self, x))
|
||||||
|
.filter(lambda x: isinstance(x, MethodType))
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||||
|
"""默认的 dispatch 方法"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
||||||
|
"""使用 catch_func 获取并分配参数"""
|
||||||
|
|
||||||
|
def dispatch(self, func: Callable[P, R]) -> Callable[..., R]:
|
||||||
|
"""将参数分配给函数,从而合成一个无需参数即可执行的函数"""
|
||||||
|
params = {}
|
||||||
|
signature = get_signature(func)
|
||||||
|
parameters: Dict[str, Parameter] = dict(signature.parameters)
|
||||||
|
|
||||||
|
for name, parameter in list(parameters.items()):
|
||||||
|
parameter: Parameter
|
||||||
|
if any(
|
||||||
|
[
|
||||||
|
name == "self" and isinstance(func, (type, MethodType)),
|
||||||
|
parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL],
|
||||||
|
]
|
||||||
|
):
|
||||||
|
del parameters[name]
|
||||||
|
continue
|
||||||
|
|
||||||
|
for dispatch_func in self.dispatch_funcs:
|
||||||
|
parameters[name] = dispatch_func(parameter)
|
||||||
|
|
||||||
|
for name, parameter in parameters.items():
|
||||||
|
if parameter.default != Parameter.empty:
|
||||||
|
params[name] = parameter.default
|
||||||
|
else:
|
||||||
|
params[name] = None
|
||||||
|
|
||||||
|
return partial(func, **params)
|
||||||
|
|
||||||
|
@catch(Application)
|
||||||
|
def catch_application(self) -> Application:
|
||||||
|
return self.application
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDispatcher(AbstractDispatcher):
|
||||||
|
"""默认参数分发器"""
|
||||||
|
|
||||||
|
_instances: Sequence[Any]
|
||||||
|
|
||||||
|
def _get_kwargs(self) -> Dict[Type[T], T]:
|
||||||
|
result = self._get_default_kwargs()
|
||||||
|
result[AbstractDispatcher] = self
|
||||||
|
result.update(self._kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_default_kwargs(self) -> Dict[Type[T], T]:
|
||||||
|
application = self.application
|
||||||
|
_default_kwargs = {
|
||||||
|
FastAPI: application.web_app,
|
||||||
|
Server: application.web_server,
|
||||||
|
TelegramApplication: application.telegram,
|
||||||
|
TelegramBot: application.telegram.bot,
|
||||||
|
}
|
||||||
|
if not application.running:
|
||||||
|
for obj in chain(
|
||||||
|
application.managers.dependency,
|
||||||
|
application.managers.components,
|
||||||
|
application.managers.services,
|
||||||
|
application.managers.plugins,
|
||||||
|
):
|
||||||
|
_default_kwargs[type(obj)] = obj
|
||||||
|
return {k: v for k, v in _default_kwargs.items() if v is not None}
|
||||||
|
|
||||||
|
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||||
|
annotation = parameter.annotation
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None:
|
||||||
|
parameter._default = value # pylint: disable=W0212
|
||||||
|
return parameter
|
||||||
|
|
||||||
|
def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter:
|
||||||
|
annotation = parameter.annotation
|
||||||
|
if annotation != Any and isinstance(annotation, GenericAlias):
|
||||||
|
return parameter
|
||||||
|
|
||||||
|
catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name)
|
||||||
|
if catch_func is not None:
|
||||||
|
# noinspection PyUnresolvedReferences,PyProtectedMember
|
||||||
|
parameter._default = catch_func() # pylint: disable=W0212
|
||||||
|
return parameter
|
||||||
|
|
||||||
|
@catch(AbstractEventLoop)
|
||||||
|
def catch_loop(self) -> AbstractEventLoop:
|
||||||
|
return asyncio.get_event_loop()
|
||||||
|
|
||||||
|
|
||||||
|
class HandlerDispatcher(BaseDispatcher):
|
||||||
|
"""Handler 参数分发器"""
|
||||||
|
|
||||||
|
def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
||||||
|
super().__init__(update=update, context=context, **kwargs)
|
||||||
|
self._update = update
|
||||||
|
self._context = context
|
||||||
|
|
||||||
|
def dispatch(
|
||||||
|
self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None
|
||||||
|
) -> Callable[..., R]:
|
||||||
|
self._update = update or self._update
|
||||||
|
self._context = context or self._context
|
||||||
|
if self._update is None:
|
||||||
|
from core.builtins.contexts import UpdateCV
|
||||||
|
|
||||||
|
self._update = UpdateCV.get()
|
||||||
|
if self._context is None:
|
||||||
|
from core.builtins.contexts import CallbackContextCV
|
||||||
|
|
||||||
|
self._context = CallbackContextCV.get()
|
||||||
|
return super().dispatch(func)
|
||||||
|
|
||||||
|
def dispatch_by_default(self, parameter: Parameter) -> Parameter:
|
||||||
|
"""HandlerDispatcher 默认不使用 dispatch_by_default"""
|
||||||
|
return parameter
|
||||||
|
|
||||||
|
@catch(Update)
|
||||||
|
def catch_update(self) -> Update:
|
||||||
|
return self._update
|
||||||
|
|
||||||
|
@catch(CallbackContext)
|
||||||
|
def catch_context(self) -> CallbackContext:
|
||||||
|
return self._context
|
||||||
|
|
||||||
|
@catch(Message)
|
||||||
|
def catch_message(self) -> Message:
|
||||||
|
return self._update.effective_message
|
||||||
|
|
||||||
|
@catch(User)
|
||||||
|
def catch_user(self) -> User:
|
||||||
|
return self._update.effective_user
|
||||||
|
|
||||||
|
@catch(Chat)
|
||||||
|
def catch_chat(self) -> Chat:
|
||||||
|
return self._update.effective_chat
|
||||||
|
|
||||||
|
|
||||||
|
class JobDispatcher(BaseDispatcher):
|
||||||
|
"""Job 参数分发器"""
|
||||||
|
|
||||||
|
def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None:
|
||||||
|
super().__init__(context=context, **kwargs)
|
||||||
|
self._context = context
|
||||||
|
|
||||||
|
def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]:
|
||||||
|
self._context = context or self._context
|
||||||
|
if self._context is None:
|
||||||
|
from core.builtins.contexts import CallbackContextCV
|
||||||
|
|
||||||
|
self._context = CallbackContextCV.get()
|
||||||
|
return super().dispatch(func)
|
||||||
|
|
||||||
|
@catch("data")
|
||||||
|
def catch_data(self) -> Any:
|
||||||
|
return self._context.job.data
|
||||||
|
|
||||||
|
@catch(Job)
|
||||||
|
def catch_job(self) -> Job:
|
||||||
|
return self._context.job
|
||||||
|
|
||||||
|
@catch(CallbackContext)
|
||||||
|
def catch_context(self) -> CallbackContext:
|
||||||
|
return self._context
|
||||||
|
|
||||||
|
|
||||||
|
def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher):
|
||||||
|
def decorate(func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
return dispatcher().dispatch(func)(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorate
|
131
core/builtins/executor.py
Normal file
131
core/builtins/executor.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
"""执行器"""
|
||||||
|
import inspect
|
||||||
|
from functools import cached_property
|
||||||
|
from multiprocessing import RLock as Lock
|
||||||
|
from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
|
from telegram import Update
|
||||||
|
from telegram.ext import CallbackContext
|
||||||
|
from typing_extensions import ParamSpec, Self
|
||||||
|
|
||||||
|
from core.builtins.contexts import handler_contexts, job_contexts
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.application import Application
|
||||||
|
from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher
|
||||||
|
from multiprocessing.synchronize import RLock as LockType
|
||||||
|
|
||||||
|
__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor")
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
R = TypeVar("R")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseExecutor:
|
||||||
|
"""执行器
|
||||||
|
Args:
|
||||||
|
name(str): 该执行器的名称。执行器的名称是唯一的。
|
||||||
|
|
||||||
|
只支持执行只拥有 POSITIONAL_OR_KEYWORD 和 KEYWORD_ONLY 两种参数类型的函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
_lock: ClassVar["LockType"] = Lock()
|
||||||
|
_instances: ClassVar[Dict[str, Self]] = {}
|
||||||
|
_application: "Optional[Application]" = None
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def application(self) -> "Application":
|
||||||
|
if self._application is None:
|
||||||
|
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||||
|
return self._application
|
||||||
|
|
||||||
|
def __new__(cls: Type[T], name: str, *args, **kwargs) -> T:
|
||||||
|
with cls._lock:
|
||||||
|
if (instance := cls._instances.get(name)) is None:
|
||||||
|
instance = object.__new__(cls)
|
||||||
|
instance.__init__(name, *args, **kwargs)
|
||||||
|
cls._instances.update({name: instance})
|
||||||
|
return instance
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""当前执行器的名称"""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
||||||
|
self._name = name
|
||||||
|
self._dispatcher = dispatcher
|
||||||
|
|
||||||
|
|
||||||
|
class Executor(BaseExecutor, Generic[P, R]):
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
target: Callable[P, R],
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> R:
|
||||||
|
dispatcher = self._dispatcher or dispatcher
|
||||||
|
dispatcher_instance = dispatcher(**kwargs)
|
||||||
|
dispatcher_instance.set_application(application=self.application)
|
||||||
|
dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数
|
||||||
|
|
||||||
|
# 执行
|
||||||
|
if inspect.iscoroutinefunction(target):
|
||||||
|
result = await dispatched_func()
|
||||||
|
else:
|
||||||
|
result = dispatched_func()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class HandlerExecutor(BaseExecutor, Generic[P, R]):
|
||||||
|
"""Handler专用执行器"""
|
||||||
|
|
||||||
|
_callback: Callable[P, R]
|
||||||
|
_dispatcher: "HandlerDispatcher"
|
||||||
|
|
||||||
|
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None:
|
||||||
|
if dispatcher is None:
|
||||||
|
from core.builtins.dispatcher import HandlerDispatcher
|
||||||
|
|
||||||
|
dispatcher = HandlerDispatcher
|
||||||
|
super().__init__("handler", dispatcher)
|
||||||
|
self._callback = func
|
||||||
|
self._dispatcher = dispatcher()
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
if self._dispatcher is not None:
|
||||||
|
self._dispatcher.set_application(application)
|
||||||
|
|
||||||
|
async def __call__(self, update: Update, context: CallbackContext) -> R:
|
||||||
|
with handler_contexts(update, context):
|
||||||
|
dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context)
|
||||||
|
return await dispatched_func()
|
||||||
|
|
||||||
|
|
||||||
|
class JobExecutor(BaseExecutor):
|
||||||
|
"""Job 专用执行器"""
|
||||||
|
|
||||||
|
def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None:
|
||||||
|
if dispatcher is None:
|
||||||
|
from core.builtins.dispatcher import JobDispatcher
|
||||||
|
|
||||||
|
dispatcher = JobDispatcher
|
||||||
|
super().__init__("job", dispatcher)
|
||||||
|
self._callback = func
|
||||||
|
self._dispatcher = dispatcher()
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
if self._dispatcher is not None:
|
||||||
|
self._dispatcher.set_application(application)
|
||||||
|
|
||||||
|
async def __call__(self, context: CallbackContext) -> R:
|
||||||
|
with job_contexts(context):
|
||||||
|
dispatched_func = self._dispatcher.dispatch(self._callback, context=context)
|
||||||
|
return await dispatched_func()
|
185
core/builtins/reloader.py
Normal file
185
core/builtins/reloader.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
import inspect
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Iterator, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from watchfiles import watch
|
||||||
|
|
||||||
|
from utils.const import HANDLED_SIGNALS, PROJECT_ROOT
|
||||||
|
from utils.log import logger
|
||||||
|
from utils.typedefs import StrOrPath
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from multiprocessing.process import BaseProcess
|
||||||
|
|
||||||
|
__all__ = ("Reloader",)
|
||||||
|
|
||||||
|
multiprocessing.allow_connection_pickling()
|
||||||
|
spawn = multiprocessing.get_context("spawn")
|
||||||
|
|
||||||
|
|
||||||
|
class FileFilter:
|
||||||
|
"""监控文件过滤"""
|
||||||
|
|
||||||
|
def __init__(self, includes: List[str], excludes: List[str]) -> None:
|
||||||
|
default_includes = ["*.py"]
|
||||||
|
self.includes = [default for default in default_includes if default not in excludes]
|
||||||
|
self.includes.extend(includes)
|
||||||
|
self.includes = list(set(self.includes))
|
||||||
|
|
||||||
|
default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__]
|
||||||
|
self.excludes = [default for default in default_excludes if default not in includes]
|
||||||
|
self.exclude_dirs = []
|
||||||
|
for e in excludes:
|
||||||
|
p = Path(e)
|
||||||
|
try:
|
||||||
|
is_dir = p.is_dir()
|
||||||
|
except OSError:
|
||||||
|
is_dir = False
|
||||||
|
|
||||||
|
if is_dir:
|
||||||
|
self.exclude_dirs.append(p)
|
||||||
|
else:
|
||||||
|
self.excludes.append(e)
|
||||||
|
self.excludes = list(set(self.excludes))
|
||||||
|
|
||||||
|
def __call__(self, path: Path) -> bool:
|
||||||
|
for include_pattern in self.includes:
|
||||||
|
if path.match(include_pattern):
|
||||||
|
for exclude_dir in self.exclude_dirs:
|
||||||
|
if exclude_dir in path.parents:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for exclude_pattern in self.excludes:
|
||||||
|
if path.match(exclude_pattern):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Reloader:
|
||||||
|
_target: Callable[..., None]
|
||||||
|
_process: "BaseProcess"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def process(self) -> "BaseProcess":
|
||||||
|
return self._process
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target(self) -> Callable[..., None]:
|
||||||
|
return self._target
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target: Callable[..., None],
|
||||||
|
*,
|
||||||
|
reload_delay: float = 0.25,
|
||||||
|
reload_dirs: List[StrOrPath] = None,
|
||||||
|
reload_includes: List[str] = None,
|
||||||
|
reload_excludes: List[str] = None,
|
||||||
|
):
|
||||||
|
if inspect.iscoroutinefunction(target):
|
||||||
|
raise ValueError("不支持异步函数")
|
||||||
|
self._target = target
|
||||||
|
|
||||||
|
self.reload_delay = reload_delay
|
||||||
|
|
||||||
|
_reload_dirs = []
|
||||||
|
for reload_dir in reload_dirs or []:
|
||||||
|
_reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir)))
|
||||||
|
|
||||||
|
self.reload_dirs = []
|
||||||
|
for reload_dir in _reload_dirs:
|
||||||
|
append = True
|
||||||
|
for parent in reload_dir.parents:
|
||||||
|
if parent in _reload_dirs:
|
||||||
|
append = False
|
||||||
|
break
|
||||||
|
if append:
|
||||||
|
self.reload_dirs.append(reload_dir)
|
||||||
|
|
||||||
|
if not self.reload_dirs:
|
||||||
|
logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"})
|
||||||
|
|
||||||
|
self._should_exit = threading.Event()
|
||||||
|
|
||||||
|
frame = inspect.currentframe().f_back
|
||||||
|
|
||||||
|
self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]])
|
||||||
|
self.watcher = watch(
|
||||||
|
*self.reload_dirs,
|
||||||
|
watch_filter=None,
|
||||||
|
stop_event=self._should_exit,
|
||||||
|
yield_on_timeout=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_changes(self) -> Optional[List[Path]]:
|
||||||
|
if not self._process.is_alive():
|
||||||
|
logger.info("目标进程已经关闭", extra={"tag": "Reloader"})
|
||||||
|
self._should_exit.set()
|
||||||
|
try:
|
||||||
|
changes = next(self.watcher)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
if changes:
|
||||||
|
unique_paths = {Path(c[1]) for c in changes}
|
||||||
|
return [p for p in unique_paths if self.watch_filter(p)]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Optional[List[Path]]]:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self) -> Optional[List[Path]]:
|
||||||
|
return self.get_changes()
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
self.startup()
|
||||||
|
for changes in self:
|
||||||
|
if changes:
|
||||||
|
logger.warning(
|
||||||
|
"检测到文件 %s 发生改变, 正在重载...",
|
||||||
|
[str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes],
|
||||||
|
extra={"tag": "Reloader"},
|
||||||
|
)
|
||||||
|
self.restart()
|
||||||
|
|
||||||
|
self.shutdown()
|
||||||
|
|
||||||
|
def signal_handler(self, *_) -> None:
|
||||||
|
"""当接收到结束信号量时"""
|
||||||
|
self._process.join(3)
|
||||||
|
if self._process.is_alive():
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.join()
|
||||||
|
self._should_exit.set()
|
||||||
|
|
||||||
|
def startup(self) -> None:
|
||||||
|
"""启动进程"""
|
||||||
|
logger.info("目标进程正在启动", extra={"tag": "Reloader"})
|
||||||
|
|
||||||
|
for sig in HANDLED_SIGNALS:
|
||||||
|
signal.signal(sig, self.signal_handler)
|
||||||
|
|
||||||
|
self._process = spawn.Process(target=self._target)
|
||||||
|
self._process.start()
|
||||||
|
logger.success("目标进程启动成功", extra={"tag": "Reloader"})
|
||||||
|
|
||||||
|
def restart(self) -> None:
|
||||||
|
"""重启进程"""
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.join(10)
|
||||||
|
|
||||||
|
self._process = spawn.Process(target=self._target)
|
||||||
|
self._process.start()
|
||||||
|
logger.info("目标进程已经重载", extra={"tag": "Reloader"})
|
||||||
|
|
||||||
|
def shutdown(self) -> None:
|
||||||
|
"""关闭进程"""
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.join(10)
|
||||||
|
|
||||||
|
logger.info("重载器已经关闭", extra={"tag": "Reloader"})
|
@ -1,19 +1,15 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import List, Optional, Union
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
from pydantic import AnyUrl, BaseModel, Field
|
from pydantic import AnyUrl, Field
|
||||||
|
|
||||||
|
from core.basemodel import Settings
|
||||||
from utils.const import PROJECT_ROOT
|
from utils.const import PROJECT_ROOT
|
||||||
from utils.models.base import Settings
|
|
||||||
from utils.typedefs import NaturalNumber
|
from utils.typedefs import NaturalNumber
|
||||||
|
|
||||||
__all__ = ["BotConfig", "config", "JoinGroups"]
|
__all__ = ("ApplicationConfig", "config", "JoinGroups")
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@ -25,22 +21,12 @@ class JoinGroups(str, Enum):
|
|||||||
ALLOW_ALL = "ALLOW_ALL"
|
ALLOW_ALL = "ALLOW_ALL"
|
||||||
|
|
||||||
|
|
||||||
class ConfigChannel(BaseModel):
|
|
||||||
name: str
|
|
||||||
chat_id: int
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigUser(BaseModel):
|
|
||||||
username: Optional[str]
|
|
||||||
user_id: int
|
|
||||||
|
|
||||||
|
|
||||||
class MySqlConfig(Settings):
|
class MySqlConfig(Settings):
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
port: int = 3306
|
port: int = 3306
|
||||||
username: str = None
|
username: Optional[str] = None
|
||||||
password: str = None
|
password: Optional[str] = None
|
||||||
database: str = None
|
database: Optional[str] = None
|
||||||
|
|
||||||
class Config(Settings.Config):
|
class Config(Settings.Config):
|
||||||
env_prefix = "db_"
|
env_prefix = "db_"
|
||||||
@ -58,7 +44,7 @@ class RedisConfig(Settings):
|
|||||||
|
|
||||||
class LoggerConfig(Settings):
|
class LoggerConfig(Settings):
|
||||||
name: str = "TGPaimon"
|
name: str = "TGPaimon"
|
||||||
width: int = 180
|
width: Optional[int] = None
|
||||||
time_format: str = "[%Y-%m-%d %X]"
|
time_format: str = "[%Y-%m-%d %X]"
|
||||||
traceback_max_frames: int = 20
|
traceback_max_frames: int = 20
|
||||||
path: Path = PROJECT_ROOT / "logs"
|
path: Path = PROJECT_ROOT / "logs"
|
||||||
@ -78,6 +64,9 @@ class MTProtoConfig(Settings):
|
|||||||
|
|
||||||
|
|
||||||
class WebServerConfig(Settings):
|
class WebServerConfig(Settings):
|
||||||
|
enable: bool = False
|
||||||
|
"""是否启用WebServer"""
|
||||||
|
|
||||||
url: AnyUrl = "http://localhost:8080"
|
url: AnyUrl = "http://localhost:8080"
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 8080
|
port: int = 8080
|
||||||
@ -97,6 +86,16 @@ class ErrorConfig(Settings):
|
|||||||
env_prefix = "error_"
|
env_prefix = "error_"
|
||||||
|
|
||||||
|
|
||||||
|
class ReloadConfig(Settings):
|
||||||
|
delay: float = 0.25
|
||||||
|
dirs: List[str] = []
|
||||||
|
include: List[str] = []
|
||||||
|
exclude: List[str] = []
|
||||||
|
|
||||||
|
class Config(Settings.Config):
|
||||||
|
env_prefix = "reload_"
|
||||||
|
|
||||||
|
|
||||||
class NoticeConfig(Settings):
|
class NoticeConfig(Settings):
|
||||||
user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!"
|
user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!"
|
||||||
|
|
||||||
@ -104,24 +103,32 @@ class NoticeConfig(Settings):
|
|||||||
env_prefix = "notice_"
|
env_prefix = "notice_"
|
||||||
|
|
||||||
|
|
||||||
class PluginConfig(Settings):
|
class ApplicationConfig(Settings):
|
||||||
download_file_max_size: int = 5
|
|
||||||
|
|
||||||
class Config(Settings.Config):
|
|
||||||
env_prefix = "plugin_"
|
|
||||||
|
|
||||||
|
|
||||||
class BotConfig(Settings):
|
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
"""debug 开关"""
|
||||||
|
retry: int = 5
|
||||||
|
"""重试次数"""
|
||||||
|
auto_reload: bool = False
|
||||||
|
"""自动重载"""
|
||||||
|
|
||||||
|
proxy_url: Optional[AnyUrl] = None
|
||||||
|
"""代理链接"""
|
||||||
|
|
||||||
bot_token: str = ""
|
bot_token: str = ""
|
||||||
|
"""BOT的token"""
|
||||||
|
|
||||||
|
owner: Optional[int] = None
|
||||||
|
|
||||||
|
channels: List[int] = []
|
||||||
|
"""文章推送群组"""
|
||||||
|
|
||||||
channels: List["ConfigChannel"] = []
|
|
||||||
admins: List["ConfigUser"] = []
|
|
||||||
verify_groups: List[Union[int, str]] = []
|
verify_groups: List[Union[int, str]] = []
|
||||||
|
"""启用群验证功能的群组"""
|
||||||
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
|
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
|
||||||
|
"""是否允许机器人被邀请到其它群组"""
|
||||||
|
|
||||||
timeout: int = 10
|
timeout: int = 10
|
||||||
|
connection_pool_size: int = 256
|
||||||
read_timeout: Optional[float] = None
|
read_timeout: Optional[float] = None
|
||||||
write_timeout: Optional[float] = None
|
write_timeout: Optional[float] = None
|
||||||
connect_timeout: Optional[float] = None
|
connect_timeout: Optional[float] = None
|
||||||
@ -138,6 +145,7 @@ class BotConfig(Settings):
|
|||||||
pass_challenge_app_key: str = ""
|
pass_challenge_app_key: str = ""
|
||||||
pass_challenge_user_web: str = ""
|
pass_challenge_user_web: str = ""
|
||||||
|
|
||||||
|
reload: ReloadConfig = ReloadConfig()
|
||||||
mysql: MySqlConfig = MySqlConfig()
|
mysql: MySqlConfig = MySqlConfig()
|
||||||
logger: LoggerConfig = LoggerConfig()
|
logger: LoggerConfig = LoggerConfig()
|
||||||
webserver: WebServerConfig = WebServerConfig()
|
webserver: WebServerConfig = WebServerConfig()
|
||||||
@ -145,8 +153,7 @@ class BotConfig(Settings):
|
|||||||
mtproto: MTProtoConfig = MTProtoConfig()
|
mtproto: MTProtoConfig = MTProtoConfig()
|
||||||
error: ErrorConfig = ErrorConfig()
|
error: ErrorConfig = ErrorConfig()
|
||||||
notice: NoticeConfig = NoticeConfig()
|
notice: NoticeConfig = NoticeConfig()
|
||||||
plugin: PluginConfig = PluginConfig()
|
|
||||||
|
|
||||||
|
|
||||||
BotConfig.update_forward_refs()
|
ApplicationConfig.update_forward_refs()
|
||||||
config = BotConfig()
|
config = ApplicationConfig()
|
||||||
|
@ -1,21 +0,0 @@
|
|||||||
from core.base.mysql import MySQL
|
|
||||||
from core.base.redisdb import RedisDB
|
|
||||||
from core.cookies.cache import PublicCookiesCache
|
|
||||||
from core.cookies.repositories import CookiesRepository
|
|
||||||
from core.cookies.services import CookiesService, PublicCookiesService
|
|
||||||
from core.service import init_service
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_cookie_service(mysql: MySQL):
|
|
||||||
_repository = CookiesRepository(mysql)
|
|
||||||
_service = CookiesService(_repository)
|
|
||||||
return _service
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_public_cookie_service(mysql: MySQL, redis: RedisDB):
|
|
||||||
_repository = CookiesRepository(mysql)
|
|
||||||
_cache = PublicCookiesCache(redis)
|
|
||||||
_service = PublicCookiesService(_repository, _cache)
|
|
||||||
return _service
|
|
@ -1,27 +0,0 @@
|
|||||||
import enum
|
|
||||||
from typing import Optional, Dict
|
|
||||||
|
|
||||||
from sqlmodel import SQLModel, Field, JSON, Enum, Column
|
|
||||||
|
|
||||||
|
|
||||||
class CookiesStatusEnum(int, enum.Enum):
|
|
||||||
STATUS_SUCCESS = 0
|
|
||||||
INVALID_COOKIES = 1
|
|
||||||
TOO_MANY_REQUESTS = 2
|
|
||||||
|
|
||||||
|
|
||||||
class Cookies(SQLModel):
|
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
|
||||||
|
|
||||||
id: int = Field(primary_key=True)
|
|
||||||
user_id: Optional[int] = Field(foreign_key="user.user_id")
|
|
||||||
cookies: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
|
|
||||||
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
|
|
||||||
|
|
||||||
|
|
||||||
class HyperionCookie(Cookies, table=True):
|
|
||||||
__tablename__ = "mihoyo_cookies"
|
|
||||||
|
|
||||||
|
|
||||||
class HoyolabCookie(Cookies, table=True):
|
|
||||||
__tablename__ = "hoyoverse_cookies"
|
|
@ -1,109 +0,0 @@
|
|||||||
from typing import cast, List
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.exc import NoResultFound
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from core.base.mysql import MySQL
|
|
||||||
from utils.error import RegionNotFoundError
|
|
||||||
from utils.models.base import RegionEnum
|
|
||||||
from .error import CookiesNotFoundError
|
|
||||||
from .models import HyperionCookie, HoyolabCookie, Cookies
|
|
||||||
|
|
||||||
|
|
||||||
class CookiesRepository:
|
|
||||||
def __init__(self, mysql: MySQL):
|
|
||||||
self.mysql = mysql
|
|
||||||
|
|
||||||
async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
session = cast(AsyncSession, session)
|
|
||||||
if region == RegionEnum.HYPERION:
|
|
||||||
db_data = HyperionCookie(user_id=user_id, cookies=cookies)
|
|
||||||
elif region == RegionEnum.HOYOLAB:
|
|
||||||
db_data = HoyolabCookie(user_id=user_id, cookies=cookies)
|
|
||||||
else:
|
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
session.add(db_data)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
session = cast(AsyncSession, session)
|
|
||||||
if region == RegionEnum.HYPERION:
|
|
||||||
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
|
|
||||||
elif region == RegionEnum.HOYOLAB:
|
|
||||||
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
|
|
||||||
else:
|
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
db_cookies = results.first()
|
|
||||||
if db_cookies is None:
|
|
||||||
raise CookiesNotFoundError(user_id)
|
|
||||||
db_cookies = db_cookies[0]
|
|
||||||
db_cookies.cookies = cookies
|
|
||||||
session.add(db_cookies)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(db_cookies)
|
|
||||||
|
|
||||||
async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum):
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
session = cast(AsyncSession, session)
|
|
||||||
if region not in [RegionEnum.HYPERION, RegionEnum.HOYOLAB]:
|
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
session.add(cookies)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(cookies)
|
|
||||||
|
|
||||||
async def get_cookies(self, user_id, region: RegionEnum) -> Cookies:
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
session = cast(AsyncSession, session)
|
|
||||||
if region == RegionEnum.HYPERION:
|
|
||||||
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
db_cookies = results.first()
|
|
||||||
if db_cookies is None:
|
|
||||||
raise CookiesNotFoundError(user_id)
|
|
||||||
return db_cookies[0]
|
|
||||||
elif region == RegionEnum.HOYOLAB:
|
|
||||||
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
db_cookies = results.first()
|
|
||||||
if db_cookies is None:
|
|
||||||
raise CookiesNotFoundError(user_id)
|
|
||||||
return db_cookies[0]
|
|
||||||
else:
|
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
|
|
||||||
async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]:
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
session = cast(AsyncSession, session)
|
|
||||||
if region == RegionEnum.HYPERION:
|
|
||||||
statement = select(HyperionCookie)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
db_cookies = results.all()
|
|
||||||
return [cookies[0] for cookies in db_cookies]
|
|
||||||
elif region == RegionEnum.HOYOLAB:
|
|
||||||
statement = select(HoyolabCookie)
|
|
||||||
results = await session.exec(statement)
|
|
||||||
db_cookies = results.all()
|
|
||||||
return [cookies[0] for cookies in db_cookies]
|
|
||||||
else:
|
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
|
|
||||||
async def del_cookies(self, user_id, region: RegionEnum):
|
|
||||||
async with self.mysql.Session() as session:
|
|
||||||
session = cast(AsyncSession, session)
|
|
||||||
if region == RegionEnum.HYPERION:
|
|
||||||
statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id)
|
|
||||||
elif region == RegionEnum.HOYOLAB:
|
|
||||||
statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id)
|
|
||||||
else:
|
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
results = await session.execute(statement)
|
|
||||||
try:
|
|
||||||
db_cookies = results.unique().scalar_one()
|
|
||||||
except NoResultFound as exc:
|
|
||||||
raise CookiesNotFoundError(user_id) from exc
|
|
||||||
await session.delete(db_cookies)
|
|
||||||
await session.commit()
|
|
1
core/dependence/__init__.py
Normal file
1
core/dependence/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""基础服务"""
|
@ -1,26 +1,40 @@
|
|||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from playwright.async_api import Browser, Playwright, async_playwright, Error
|
from playwright.async_api import Error, async_playwright
|
||||||
|
|
||||||
from core.service import Service
|
from core.base_service import BaseService
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from playwright.async_api import Playwright as AsyncPlaywright, Browser
|
||||||
|
|
||||||
|
__all__ = ("AioBrowser",)
|
||||||
|
|
||||||
|
|
||||||
|
class AioBrowser(BaseService.Dependence):
|
||||||
|
@property
|
||||||
|
def browser(self):
|
||||||
|
return self._browser
|
||||||
|
|
||||||
class AioBrowser(Service):
|
|
||||||
def __init__(self, loop=None):
|
def __init__(self, loop=None):
|
||||||
self.browser: Optional[Browser] = None
|
self._browser: Optional["Browser"] = None
|
||||||
self._playwright: Optional[Playwright] = None
|
self._playwright: Optional["AsyncPlaywright"] = None
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
|
|
||||||
async def start(self):
|
async def get_browser(self):
|
||||||
|
if self._browser is None:
|
||||||
|
await self.initialize()
|
||||||
|
return self._browser
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
if self._playwright is None:
|
if self._playwright is None:
|
||||||
logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True})
|
logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True})
|
||||||
self._playwright = await async_playwright().start()
|
self._playwright = await async_playwright().start()
|
||||||
logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True})
|
logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True})
|
||||||
if self.browser is None:
|
if self._browser is None:
|
||||||
logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True})
|
logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True})
|
||||||
try:
|
try:
|
||||||
self.browser = await self._playwright.chromium.launch(timeout=5000)
|
self._browser = await self._playwright.chromium.launch(timeout=5000)
|
||||||
logger.success("[blue]Browser[/] 启动成功", extra={"markup": True})
|
logger.success("[blue]Browser[/] 启动成功", extra={"markup": True})
|
||||||
except Error as err:
|
except Error as err:
|
||||||
if "playwright install" in str(err):
|
if "playwright install" in str(err):
|
||||||
@ -33,15 +47,10 @@ class AioBrowser(Service):
|
|||||||
raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium")
|
raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium")
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
return self.browser
|
return self._browser
|
||||||
|
|
||||||
async def stop(self):
|
async def shutdown(self):
|
||||||
if self.browser is not None:
|
if self._browser is not None:
|
||||||
await self.browser.close()
|
await self._browser.close()
|
||||||
if self._playwright is not None:
|
if self._playwright is not None:
|
||||||
await self._playwright.stop()
|
self._playwright.stop()
|
||||||
|
|
||||||
async def get_browser(self) -> Browser:
|
|
||||||
if self.browser is None:
|
|
||||||
await self.start()
|
|
||||||
return self.browser
|
|
16
core/dependence/aiobrowser.pyi
Normal file
16
core/dependence/aiobrowser.pyi
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from asyncio import AbstractEventLoop
|
||||||
|
|
||||||
|
from playwright.async_api import Browser, Playwright as AsyncPlaywright
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
|
||||||
|
__all__ = ("AioBrowser",)
|
||||||
|
|
||||||
|
class AioBrowser(BaseService.Dependence):
|
||||||
|
_browser: Browser | None
|
||||||
|
_playwright: AsyncPlaywright | None
|
||||||
|
_loop: AbstractEventLoop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def browser(self) -> Browser | None: ...
|
||||||
|
async def get_browser(self) -> Browser: ...
|
@ -17,7 +17,7 @@ from enkanetwork.model.assets import CharacterAsset as EnkaCharacterAsset
|
|||||||
from httpx import AsyncClient, HTTPError, HTTPStatusError, TransportError, URL
|
from httpx import AsyncClient, HTTPError, HTTPStatusError, TransportError, URL
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from core.service import Service
|
from core.base_service import BaseService
|
||||||
from metadata.genshin import AVATAR_DATA, HONEY_DATA, MATERIAL_DATA, NAMECARD_DATA, WEAPON_DATA
|
from metadata.genshin import AVATAR_DATA, HONEY_DATA, MATERIAL_DATA, NAMECARD_DATA, WEAPON_DATA
|
||||||
from metadata.scripts.honey import update_honey_metadata
|
from metadata.scripts.honey import update_honey_metadata
|
||||||
from metadata.scripts.metadatas import update_metadata_from_ambr, update_metadata_from_github
|
from metadata.scripts.metadatas import update_metadata_from_ambr, update_metadata_from_github
|
||||||
@ -31,6 +31,8 @@ if TYPE_CHECKING:
|
|||||||
from httpx import Response
|
from httpx import Response
|
||||||
from multiprocessing.synchronize import RLock
|
from multiprocessing.synchronize import RLock
|
||||||
|
|
||||||
|
__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets")
|
||||||
|
|
||||||
ICON_TYPE = Union[Callable[[bool], Awaitable[Optional[Path]]], Callable[..., Awaitable[Optional[Path]]]]
|
ICON_TYPE = Union[Callable[[bool], Awaitable[Optional[Path]]], Callable[..., Awaitable[Optional[Path]]]]
|
||||||
NAME_MAP_TYPE = Dict[str, StrOrURL]
|
NAME_MAP_TYPE = Dict[str, StrOrURL]
|
||||||
|
|
||||||
@ -127,7 +129,7 @@ class _AssetsService(ABC):
|
|||||||
|
|
||||||
async def _download(self, url: StrOrURL, path: Path, retry: int = 5) -> Path | None:
|
async def _download(self, url: StrOrURL, path: Path, retry: int = 5) -> Path | None:
|
||||||
"""从 url 下载图标至 path"""
|
"""从 url 下载图标至 path"""
|
||||||
logger.debug(f"正在从 {url} 下载图标至 {path}")
|
logger.debug("正在从 %s 下载图标至 %s", url, path)
|
||||||
headers = {"user-agent": "TGPaimonBot/3.0"} if URL(url).host == "enka.network" else None
|
headers = {"user-agent": "TGPaimonBot/3.0"} if URL(url).host == "enka.network" else None
|
||||||
for time in range(retry):
|
for time in range(retry):
|
||||||
try:
|
try:
|
||||||
@ -204,8 +206,8 @@ class _AssetsService(ABC):
|
|||||||
"""魔法"""
|
"""魔法"""
|
||||||
if item in self.icon_types:
|
if item in self.icon_types:
|
||||||
return partial(self._get_img, item=item)
|
return partial(self._get_img, item=item)
|
||||||
else:
|
object.__getattribute__(self, item)
|
||||||
object.__getattribute__(self, item)
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@cached_property
|
@cached_property
|
||||||
@ -498,7 +500,7 @@ class _NamecardAssets(_AssetsService):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class AssetsService(Service):
|
class AssetsService(BaseService.Dependence):
|
||||||
"""asset服务
|
"""asset服务
|
||||||
|
|
||||||
用于储存和管理 asset :
|
用于储存和管理 asset :
|
||||||
@ -527,8 +529,10 @@ class AssetsService(Service):
|
|||||||
):
|
):
|
||||||
setattr(self, attr, globals()[assets_type_name]())
|
setattr(self, attr, globals()[assets_type_name]())
|
||||||
|
|
||||||
async def start(self): # pylint: disable=R0201
|
async def initialize(self) -> None: # pylint: disable=R0201
|
||||||
|
"""启动 AssetsService 服务,刷新元数据"""
|
||||||
logger.info("正在刷新元数据")
|
logger.info("正在刷新元数据")
|
||||||
|
# todo 这3个任务同时异步下载
|
||||||
await update_metadata_from_github(False)
|
await update_metadata_from_github(False)
|
||||||
await update_metadata_from_ambr(False)
|
await update_metadata_from_ambr(False)
|
||||||
await update_honey_metadata(False)
|
await update_honey_metadata(False)
|
167
core/dependence/assets.pyi
Normal file
167
core/dependence/assets.pyi
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Awaitable, Callable, ClassVar, TypeVar
|
||||||
|
|
||||||
|
from enkanetwork import Assets as EnkaAssets
|
||||||
|
from enkanetwork.model.assets import CharacterAsset as EnkaCharacterAsset
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from utils.typedefs import StrOrInt
|
||||||
|
|
||||||
|
__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets")
|
||||||
|
|
||||||
|
ICON_TYPE = Callable[[bool], Awaitable[Path | None]] | Callable[..., Awaitable[Path | None]]
|
||||||
|
DEFAULT_EnkaAssets: EnkaAssets
|
||||||
|
_GET_TYPE = partial | list[str] | int | str | ICON_TYPE | Path | AsyncClient | None | Self | dict[str, str]
|
||||||
|
|
||||||
|
class AssetsServiceError(Exception): ...
|
||||||
|
|
||||||
|
class AssetsCouldNotFound(AssetsServiceError):
|
||||||
|
message: str
|
||||||
|
target: str
|
||||||
|
def __init__(self, message: str, target: str): ...
|
||||||
|
|
||||||
|
class _AssetsService(ABC):
|
||||||
|
icon_types: ClassVar[list[str]]
|
||||||
|
id: int
|
||||||
|
type: str
|
||||||
|
|
||||||
|
icon: ICON_TYPE
|
||||||
|
"""图标"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
@property
|
||||||
|
def game_name(self) -> str:
|
||||||
|
"""游戏数据中的名称"""
|
||||||
|
@property
|
||||||
|
def honey_id(self) -> str:
|
||||||
|
"""当前资源在 Honey Impact 所对应的 ID"""
|
||||||
|
@property
|
||||||
|
def path(self) -> Path:
|
||||||
|
"""当前资源的文件夹"""
|
||||||
|
@property
|
||||||
|
def client(self) -> AsyncClient:
|
||||||
|
"""当前的 http client"""
|
||||||
|
def __init__(self, client: AsyncClient | None = None) -> None: ...
|
||||||
|
def __call__(self, target: int) -> Self:
|
||||||
|
"""用于生成与 target 对应的 assets"""
|
||||||
|
def __getattr__(self, item: str) -> _GET_TYPE:
|
||||||
|
"""魔法"""
|
||||||
|
async def get_link(self, item: str) -> str | None:
|
||||||
|
"""获取相应图标链接"""
|
||||||
|
@abstractmethod
|
||||||
|
@property
|
||||||
|
def game_name_map(self) -> dict[str, str]:
|
||||||
|
"""游戏中的图标名"""
|
||||||
|
@abstractmethod
|
||||||
|
@property
|
||||||
|
def honey_name_map(self) -> dict[str, str]:
|
||||||
|
"""来自honey的图标名"""
|
||||||
|
|
||||||
|
class _AvatarAssets(_AssetsService):
|
||||||
|
enka: EnkaCharacterAsset | None
|
||||||
|
|
||||||
|
side: ICON_TYPE
|
||||||
|
"""侧视图图标"""
|
||||||
|
|
||||||
|
card: ICON_TYPE
|
||||||
|
"""卡片图标"""
|
||||||
|
|
||||||
|
gacha: ICON_TYPE
|
||||||
|
"""抽卡立绘"""
|
||||||
|
|
||||||
|
gacha_card: ICON_TYPE
|
||||||
|
"""抽卡卡片"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def honey_name_map(self) -> dict[str, str]: ...
|
||||||
|
@property
|
||||||
|
def game_name_map(self) -> dict[str, str]: ...
|
||||||
|
@property
|
||||||
|
def enka(self) -> EnkaCharacterAsset | None: ...
|
||||||
|
def __init__(self, client: AsyncClient | None = None, enka: EnkaAssets | None = None) -> None: ...
|
||||||
|
def __call__(self, target: StrOrInt) -> Self: ...
|
||||||
|
def __getitem__(self, item: str) -> _GET_TYPE | EnkaCharacterAsset: ...
|
||||||
|
def game_name(self) -> str: ...
|
||||||
|
|
||||||
|
class _WeaponAssets(_AssetsService):
|
||||||
|
awaken: ICON_TYPE
|
||||||
|
"""突破后图标"""
|
||||||
|
|
||||||
|
gacha: ICON_TYPE
|
||||||
|
"""抽卡立绘"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def honey_name_map(self) -> dict[str, str]: ...
|
||||||
|
@property
|
||||||
|
def game_name_map(self) -> dict[str, str]: ...
|
||||||
|
def __call__(self, target: StrOrInt) -> Self: ...
|
||||||
|
def game_name(self) -> str: ...
|
||||||
|
|
||||||
|
class _MaterialAssets(_AssetsService):
|
||||||
|
@property
|
||||||
|
def honey_name_map(self) -> dict[str, str]: ...
|
||||||
|
@property
|
||||||
|
def game_name_map(self) -> dict[str, str]: ...
|
||||||
|
def __call__(self, target: StrOrInt) -> Self: ...
|
||||||
|
def game_name(self) -> str: ...
|
||||||
|
|
||||||
|
class _ArtifactAssets(_AssetsService):
|
||||||
|
flower: ICON_TYPE
|
||||||
|
"""生之花"""
|
||||||
|
|
||||||
|
plume: ICON_TYPE
|
||||||
|
"""死之羽"""
|
||||||
|
|
||||||
|
sands: ICON_TYPE
|
||||||
|
"""时之沙"""
|
||||||
|
|
||||||
|
goblet: ICON_TYPE
|
||||||
|
"""空之杯"""
|
||||||
|
|
||||||
|
circlet: ICON_TYPE
|
||||||
|
"""理之冠"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def honey_name_map(self) -> dict[str, str]: ...
|
||||||
|
@property
|
||||||
|
def game_name_map(self) -> dict[str, str]: ...
|
||||||
|
def game_name(self) -> str: ...
|
||||||
|
|
||||||
|
class _NamecardAssets(_AssetsService):
|
||||||
|
enka: EnkaCharacterAsset | None
|
||||||
|
|
||||||
|
navbar: ICON_TYPE
|
||||||
|
"""好友名片背景"""
|
||||||
|
|
||||||
|
profile: ICON_TYPE
|
||||||
|
"""个人资料名片背景"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def honey_name_map(self) -> dict[str, str]: ...
|
||||||
|
@property
|
||||||
|
def game_name_map(self) -> dict[str, str]: ...
|
||||||
|
def game_name(self) -> str: ...
|
||||||
|
|
||||||
|
class AssetsService(BaseService.Dependence):
|
||||||
|
avatar: _AvatarAssets
|
||||||
|
"""角色"""
|
||||||
|
|
||||||
|
weapon: _WeaponAssets
|
||||||
|
"""武器"""
|
||||||
|
|
||||||
|
material: _MaterialAssets
|
||||||
|
"""素材"""
|
||||||
|
|
||||||
|
artifact: _ArtifactAssets
|
||||||
|
"""圣遗物"""
|
||||||
|
|
||||||
|
namecard: _NamecardAssets
|
||||||
|
"""名片"""
|
||||||
|
|
||||||
|
AssetsServiceType = TypeVar("AssetsServiceType", bound=_AssetsService)
|
@ -4,6 +4,8 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.config import config as bot_config
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -13,13 +15,12 @@ try:
|
|||||||
session.log.debug = lambda *args, **kwargs: None # 关闭日记
|
session.log.debug = lambda *args, **kwargs: None # 关闭日记
|
||||||
PYROGRAM_AVAILABLE = True
|
PYROGRAM_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
Client = None
|
||||||
|
session = None
|
||||||
PYROGRAM_AVAILABLE = False
|
PYROGRAM_AVAILABLE = False
|
||||||
|
|
||||||
from core.bot import bot
|
|
||||||
from core.service import Service
|
|
||||||
|
|
||||||
|
class MTProto(BaseService.Dependence):
|
||||||
class MTProto(Service):
|
|
||||||
async def get_session(self):
|
async def get_session(self):
|
||||||
async with aiofiles.open(self.session_path, mode="r") as f:
|
async with aiofiles.open(self.session_path, mode="r") as f:
|
||||||
return await f.read()
|
return await f.read()
|
||||||
@ -32,9 +33,9 @@ class MTProto(Service):
|
|||||||
return os.path.exists(self.session_path)
|
return os.path.exists(self.session_path)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.name = "PaimonBot"
|
self.name = "paigram"
|
||||||
current_dir = os.getcwd()
|
current_dir = os.getcwd()
|
||||||
self.session_path = os.path.join(current_dir, "paimon.session")
|
self.session_path = os.path.join(current_dir, "paigram.session")
|
||||||
self.client: Optional[Client] = None
|
self.client: Optional[Client] = None
|
||||||
self.proxy: Optional[dict] = None
|
self.proxy: Optional[dict] = None
|
||||||
http_proxy = os.environ.get("HTTP_PROXY")
|
http_proxy = os.environ.get("HTTP_PROXY")
|
||||||
@ -42,25 +43,25 @@ class MTProto(Service):
|
|||||||
http_proxy_url = urlparse(http_proxy)
|
http_proxy_url = urlparse(http_proxy)
|
||||||
self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port}
|
self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port}
|
||||||
|
|
||||||
async def start(self): # pylint: disable=W0221
|
async def initialize(self): # pylint: disable=W0221
|
||||||
if not PYROGRAM_AVAILABLE:
|
if not PYROGRAM_AVAILABLE:
|
||||||
logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None")
|
logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None")
|
||||||
return
|
return
|
||||||
if bot.config.mtproto.api_id is None:
|
if bot_config.mtproto.api_id is None:
|
||||||
logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None")
|
logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None")
|
||||||
return
|
return
|
||||||
if bot.config.mtproto.api_hash is None:
|
if bot_config.mtproto.api_hash is None:
|
||||||
logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None")
|
logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None")
|
||||||
return
|
return
|
||||||
self.client = Client(
|
self.client = Client(
|
||||||
api_id=bot.config.mtproto.api_id,
|
api_id=bot_config.mtproto.api_id,
|
||||||
api_hash=bot.config.mtproto.api_hash,
|
api_hash=bot_config.mtproto.api_hash,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
bot_token=bot.config.bot_token,
|
bot_token=bot_config.bot_token,
|
||||||
proxy=self.proxy,
|
proxy=self.proxy,
|
||||||
)
|
)
|
||||||
await self.client.start()
|
await self.client.start()
|
||||||
|
|
||||||
async def stop(self): # pylint: disable=W0221
|
async def shutdown(self): # pylint: disable=W0221
|
||||||
if self.client is not None:
|
if self.client is not None:
|
||||||
await self.client.stop(block=False)
|
await self.client.stop(block=False)
|
31
core/dependence/mtproto.pyi
Normal file
31
core/dependence/mtproto.pyi
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyrogram import Client
|
||||||
|
from pyrogram.session import session
|
||||||
|
|
||||||
|
PYROGRAM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
Client = None
|
||||||
|
session = None
|
||||||
|
PYROGRAM_AVAILABLE = False
|
||||||
|
|
||||||
|
__all__ = ("MTProto",)
|
||||||
|
|
||||||
|
class _ProxyType(TypedDict):
|
||||||
|
scheme: str
|
||||||
|
hostname: str | None
|
||||||
|
port: int | None
|
||||||
|
|
||||||
|
class MTProto(BaseService.Dependence):
|
||||||
|
name: str
|
||||||
|
session_path: str
|
||||||
|
client: Client | None
|
||||||
|
proxy: _ProxyType | None
|
||||||
|
|
||||||
|
async def get_session(self) -> str: ...
|
||||||
|
async def set_session(self, b: str) -> None: ...
|
||||||
|
def session_exists(self) -> bool: ...
|
50
core/dependence/mysql.py
Normal file
50
core/dependence/mysql.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import contextlib
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy.engine import URL
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.config import ApplicationConfig
|
||||||
|
from core.sqlmodel.session import AsyncSession
|
||||||
|
|
||||||
|
__all__ = ("MySQL",)
|
||||||
|
|
||||||
|
|
||||||
|
class MySQL(BaseService.Dependence):
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: ApplicationConfig) -> Self:
|
||||||
|
return cls(**config.mysql.dict())
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: Optional[str] = None,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
database: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.database = database
|
||||||
|
self.password = password
|
||||||
|
self.username = username
|
||||||
|
self.port = port
|
||||||
|
self.host = host
|
||||||
|
self.url = URL.create(
|
||||||
|
"mysql+asyncmy",
|
||||||
|
username=self.username,
|
||||||
|
password=self.password,
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
database=self.database,
|
||||||
|
)
|
||||||
|
self.engine = create_async_engine(self.url)
|
||||||
|
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession)
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def session(self) -> AsyncSession:
|
||||||
|
yield self.Session()
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
self.Session.close_all()
|
@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import fakeredis.aioredis
|
import fakeredis.aioredis
|
||||||
@ -6,14 +5,16 @@ from redis import asyncio as aioredis
|
|||||||
from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError
|
from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from core.config import BotConfig
|
from core.base_service import BaseService
|
||||||
from core.service import Service
|
from core.config import ApplicationConfig
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
|
__all__ = ["RedisDB"]
|
||||||
|
|
||||||
class RedisDB(Service):
|
|
||||||
|
class RedisDB(BaseService.Dependence):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: BotConfig) -> Self:
|
def from_config(cls, config: ApplicationConfig) -> Self:
|
||||||
return cls(**config.redis.dict())
|
return cls(**config.redis.dict())
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -24,6 +25,7 @@ class RedisDB(Service):
|
|||||||
self.key_prefix = "paimon_bot"
|
self.key_prefix = "paimon_bot"
|
||||||
|
|
||||||
async def ping(self):
|
async def ping(self):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
if await self.client.ping():
|
if await self.client.ping():
|
||||||
logger.info("连接 [red]Redis[/] 成功", extra={"markup": True})
|
logger.info("连接 [red]Redis[/] 成功", extra={"markup": True})
|
||||||
else:
|
else:
|
||||||
@ -34,7 +36,7 @@ class RedisDB(Service):
|
|||||||
self.client = fakeredis.aioredis.FakeRedis()
|
self.client = fakeredis.aioredis.FakeRedis()
|
||||||
await self.ping()
|
await self.ping()
|
||||||
|
|
||||||
async def start(self): # pylint: disable=W0221
|
async def initialize(self):
|
||||||
logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True})
|
logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True})
|
||||||
try:
|
try:
|
||||||
await self.ping()
|
await self.ping()
|
||||||
@ -45,5 +47,5 @@ class RedisDB(Service):
|
|||||||
logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
|
logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True})
|
||||||
await self.start_fake_redis()
|
await self.start_fake_redis()
|
||||||
|
|
||||||
async def stop(self): # pylint: disable=W0221
|
async def shutdown(self):
|
||||||
await self.client.close()
|
await self.client.close()
|
@ -1,16 +0,0 @@
|
|||||||
from core.base.redisdb import RedisDB
|
|
||||||
from core.service import init_service
|
|
||||||
from .cache import GameCache
|
|
||||||
from .services import GameMaterialService, GameStrategyService
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_game_strategy_service(redis: RedisDB):
|
|
||||||
_cache = GameCache(redis, "game:strategy")
|
|
||||||
return GameStrategyService(_cache)
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_game_material_service(redis: RedisDB):
|
|
||||||
_cache = GameCache(redis, "game:material")
|
|
||||||
return GameMaterialService(_cache)
|
|
59
core/handler/adminhandler.py
Normal file
59
core/handler/adminhandler.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import TypeVar, TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from telegram import Update
|
||||||
|
from telegram.ext import ApplicationHandlerStop, BaseHandler
|
||||||
|
|
||||||
|
from core.error import ServiceNotFoundError
|
||||||
|
from core.services.users.services import UserAdminService
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.application import Application
|
||||||
|
from telegram.ext import Application as TelegramApplication
|
||||||
|
|
||||||
|
RT = TypeVar("RT")
|
||||||
|
UT = TypeVar("UT")
|
||||||
|
|
||||||
|
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||||
|
|
||||||
|
|
||||||
|
class AdminHandler(BaseHandler[Update, CCT]):
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None:
|
||||||
|
self.handler = handler
|
||||||
|
self.application = application
|
||||||
|
self.user_service: Optional["UserAdminService"] = None
|
||||||
|
super().__init__(self.handler.callback)
|
||||||
|
|
||||||
|
def check_update(self, update: object) -> bool:
|
||||||
|
if not isinstance(update, Update):
|
||||||
|
return False
|
||||||
|
return self.handler.check_update(update)
|
||||||
|
|
||||||
|
async def _user_service(self) -> "UserAdminService":
|
||||||
|
async with self._lock:
|
||||||
|
if self.user_service is not None:
|
||||||
|
return self.user_service
|
||||||
|
user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None)
|
||||||
|
if user_service is None:
|
||||||
|
raise ServiceNotFoundError("UserAdminService")
|
||||||
|
self.user_service = user_service
|
||||||
|
return self.user_service
|
||||||
|
|
||||||
|
async def handle_update(
|
||||||
|
self,
|
||||||
|
update: "UT",
|
||||||
|
application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]",
|
||||||
|
check_result: Any,
|
||||||
|
context: "CCT",
|
||||||
|
) -> RT:
|
||||||
|
user_service = await self._user_service()
|
||||||
|
user = update.effective_user
|
||||||
|
if await user_service.is_admin(user.id):
|
||||||
|
return await self.handler.handle_update(update, application, check_result, context)
|
||||||
|
message = update.effective_message
|
||||||
|
logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id)
|
||||||
|
await message.reply_text("权限不足")
|
||||||
|
raise ApplicationHandlerStop
|
62
core/handler/callbackqueryhandler.py
Normal file
62
core/handler/callbackqueryhandler.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import asyncio
|
||||||
|
from contextlib import AbstractAsyncContextManager
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type
|
||||||
|
|
||||||
|
from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop
|
||||||
|
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from telegram.ext import Application
|
||||||
|
|
||||||
|
RT = TypeVar("RT")
|
||||||
|
UT = TypeVar("UT")
|
||||||
|
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||||
|
|
||||||
|
|
||||||
|
class OverlappingException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OverlappingContext(AbstractAsyncContextManager):
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def __init__(self, context: "CCT"):
|
||||||
|
self.context = context
|
||||||
|
|
||||||
|
async def __aenter__(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
flag = self.context.user_data.get("overlapping", False)
|
||||||
|
if flag:
|
||||||
|
raise OverlappingException
|
||||||
|
self.context.user_data["overlapping"] = True
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc: Optional[BaseException],
|
||||||
|
tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
del self.context.user_data["overlapping"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackQueryHandler(BaseCallbackQueryHandler):
|
||||||
|
async def handle_update(
|
||||||
|
self,
|
||||||
|
update: "UT",
|
||||||
|
application: "Application[Any, CCT, Any, Any, Any, Any]",
|
||||||
|
check_result: Any,
|
||||||
|
context: "CCT",
|
||||||
|
) -> RT:
|
||||||
|
self.collect_additional_context(context, update, application, check_result)
|
||||||
|
try:
|
||||||
|
async with OverlappingContext(context):
|
||||||
|
return await self.callback(update, context)
|
||||||
|
except OverlappingException as exc:
|
||||||
|
user = update.effective_user
|
||||||
|
logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id)
|
||||||
|
raise ApplicationHandlerStop from exc
|
71
core/handler/limiterhandler.py
Normal file
71
core/handler/limiterhandler.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import TypeVar, Optional
|
||||||
|
|
||||||
|
from telegram import Update
|
||||||
|
from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler
|
||||||
|
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
UT = TypeVar("UT")
|
||||||
|
CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]")
|
||||||
|
|
||||||
|
|
||||||
|
class LimiterHandler(TypeHandler[UT, CCT]):
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None
|
||||||
|
):
|
||||||
|
"""Limiter Handler 通过
|
||||||
|
`Leaky bucket algorithm <https://en.wikipedia.org/wiki/Leaky_bucket>`_
|
||||||
|
实现对用户的输入的精确控制
|
||||||
|
|
||||||
|
输入超过一定速率后,代码会抛出
|
||||||
|
:class:`telegram.ext.ApplicationHandlerStop`
|
||||||
|
异常并在一段时间内防止用户执行任何其他操作
|
||||||
|
|
||||||
|
:param max_rate: 在抛出异常之前最多允许 频率/秒 的速度
|
||||||
|
:param time_period: 在限制速率的时间段的持续时间
|
||||||
|
:param amount: 提供的容量
|
||||||
|
:param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount
|
||||||
|
"""
|
||||||
|
self.max_rate = max_rate
|
||||||
|
self.amount = amount
|
||||||
|
self._rate_per_sec = max_rate / time_period
|
||||||
|
self.limit_time = limit_time
|
||||||
|
super().__init__(Update, self.limiter_callback)
|
||||||
|
|
||||||
|
async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
if update.inline_query is not None:
|
||||||
|
return
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
async with self._lock:
|
||||||
|
time = loop.time()
|
||||||
|
user_data = context.user_data
|
||||||
|
if user_data is None:
|
||||||
|
return
|
||||||
|
user_limit_time = user_data.get("limit_time")
|
||||||
|
if user_limit_time is not None:
|
||||||
|
if time >= user_limit_time:
|
||||||
|
del user_data["limit_time"]
|
||||||
|
else:
|
||||||
|
raise ApplicationHandlerStop
|
||||||
|
last_task_time = user_data.get("last_task_time", 0)
|
||||||
|
if last_task_time:
|
||||||
|
task_level = user_data.get("task_level", 0)
|
||||||
|
elapsed = time - last_task_time
|
||||||
|
decrement = elapsed * self._rate_per_sec
|
||||||
|
task_level = max(task_level - decrement, 0)
|
||||||
|
user_data["task_level"] = task_level
|
||||||
|
if not task_level + self.amount <= self.max_rate:
|
||||||
|
if self.limit_time:
|
||||||
|
limit_time = self.limit_time
|
||||||
|
else:
|
||||||
|
limit_time = 1 / self._rate_per_sec * self.amount
|
||||||
|
user_data["limit_time"] = time + limit_time
|
||||||
|
user = update.effective_user
|
||||||
|
logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s 秒", user.full_name, user.id, limit_time)
|
||||||
|
raise ApplicationHandlerStop
|
||||||
|
user_data["last_task_time"] = time
|
||||||
|
task_level = user_data.get("task_level", 0)
|
||||||
|
user_data["task_level"] = task_level + self.amount
|
286
core/manager.py
Normal file
286
core/manager.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
import asyncio
|
||||||
|
from importlib import import_module
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||||
|
|
||||||
|
from arkowrapper import ArkoWrapper
|
||||||
|
from async_timeout import timeout
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services
|
||||||
|
from core.config import config as bot_config
|
||||||
|
from utils.const import PLUGIN_DIR, PROJECT_ROOT
|
||||||
|
from utils.helpers import gen_pkg
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.application import Application
|
||||||
|
from core.plugin import PluginType
|
||||||
|
from core.builtins.executor import Executor
|
||||||
|
|
||||||
|
__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers")
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_module(path: Path) -> None:
|
||||||
|
for pkg in gen_pkg(path):
|
||||||
|
try:
|
||||||
|
logger.debug('正在导入 "%s"', pkg)
|
||||||
|
import_module(pkg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
'在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True}
|
||||||
|
)
|
||||||
|
raise SystemExit from e
|
||||||
|
|
||||||
|
|
||||||
|
class Manager(Generic[T]):
|
||||||
|
"""生命周期控制基类"""
|
||||||
|
|
||||||
|
_executor: Optional["Executor"] = None
|
||||||
|
_lib: Dict[Type[T], T] = {}
|
||||||
|
_application: "Optional[Application]" = None
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def application(self) -> "Application":
|
||||||
|
if self._application is None:
|
||||||
|
raise RuntimeError(f"No application was set for this {self.__class__.__name__}.")
|
||||||
|
return self._application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def executor(self) -> "Executor":
|
||||||
|
"""执行器"""
|
||||||
|
if self._executor is None:
|
||||||
|
raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.")
|
||||||
|
return self._executor
|
||||||
|
|
||||||
|
def build_executor(self, name: str):
|
||||||
|
from core.builtins.executor import Executor
|
||||||
|
from core.builtins.dispatcher import BaseDispatcher
|
||||||
|
|
||||||
|
self._executor = Executor(name, dispatcher=BaseDispatcher)
|
||||||
|
self._executor.set_application(self.application)
|
||||||
|
|
||||||
|
|
||||||
|
class DependenceManager(Manager[DependenceType]):
|
||||||
|
"""基础依赖管理"""
|
||||||
|
|
||||||
|
_dependency: Dict[Type[DependenceType], DependenceType] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dependency(self) -> List[DependenceType]:
|
||||||
|
return list(self._dependency.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]:
|
||||||
|
return self._dependency
|
||||||
|
|
||||||
|
async def start_dependency(self) -> None:
|
||||||
|
_load_module(PROJECT_ROOT / "core/dependence")
|
||||||
|
|
||||||
|
for dependence in filter(lambda x: x.is_dependence, get_all_services()):
|
||||||
|
dependence: Type[DependenceType]
|
||||||
|
instance: DependenceType
|
||||||
|
try:
|
||||||
|
if hasattr(dependence, "from_config"): # 如果有 from_config 方法
|
||||||
|
instance = dependence.from_config(bot_config) # 用 from_config 实例化服务
|
||||||
|
else:
|
||||||
|
instance = await self.executor(dependence)
|
||||||
|
|
||||||
|
await instance.initialize()
|
||||||
|
logger.success('基础服务 "%s" 启动成功', dependence.__name__)
|
||||||
|
|
||||||
|
self._lib[dependence] = instance
|
||||||
|
self._dependency[dependence] = instance
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception('基础服务 "%s" 初始化失败,BOT 将自动关闭', dependence.__name__)
|
||||||
|
raise SystemExit from e
|
||||||
|
|
||||||
|
async def stop_dependency(self) -> None:
|
||||||
|
async def task(d):
|
||||||
|
try:
|
||||||
|
async with timeout(5):
|
||||||
|
await d.shutdown()
|
||||||
|
logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e)
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
for dependence in self._dependency.values():
|
||||||
|
tasks.append(asyncio.create_task(task(dependence)))
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentManager(Manager[ComponentType]):
|
||||||
|
"""组件管理"""
|
||||||
|
|
||||||
|
_components: Dict[Type[ComponentType], ComponentType] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components(self) -> List[ComponentType]:
|
||||||
|
return list(self._components.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components_map(self) -> Dict[Type[ComponentType], ComponentType]:
|
||||||
|
return self._components
|
||||||
|
|
||||||
|
async def init_components(self):
|
||||||
|
for path in filter(
|
||||||
|
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
||||||
|
):
|
||||||
|
_load_module(path)
|
||||||
|
components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component)
|
||||||
|
retry_times = 0
|
||||||
|
max_retry_times = len(components)
|
||||||
|
while components:
|
||||||
|
start_len = len(components)
|
||||||
|
for component in list(components):
|
||||||
|
component: Type[ComponentType]
|
||||||
|
instance: ComponentType
|
||||||
|
try:
|
||||||
|
instance = await self.executor(component)
|
||||||
|
self._lib[component] = instance
|
||||||
|
self._components[component] = instance
|
||||||
|
components = components.remove(component)
|
||||||
|
except Exception as e: # pylint: disable=W0703
|
||||||
|
logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True})
|
||||||
|
end_len = len(list(components))
|
||||||
|
if start_len == end_len:
|
||||||
|
retry_times += 1
|
||||||
|
|
||||||
|
if retry_times == max_retry_times and components:
|
||||||
|
for component in components:
|
||||||
|
logger.error('组件 "%s" 初始化失败', component.__name__)
|
||||||
|
raise SystemExit
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceManager(Manager[BaseServiceType]):
|
||||||
|
"""服务控制类"""
|
||||||
|
|
||||||
|
_services: Dict[Type[BaseServiceType], BaseServiceType] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def services(self) -> List[BaseServiceType]:
|
||||||
|
return list(self._services.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]:
|
||||||
|
return self._services
|
||||||
|
|
||||||
|
async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType:
|
||||||
|
instance: BaseServiceType
|
||||||
|
try:
|
||||||
|
if hasattr(target, "from_config"): # 如果有 from_config 方法
|
||||||
|
instance = target.from_config(bot_config) # 用 from_config 实例化服务
|
||||||
|
else:
|
||||||
|
instance = await self.executor(target)
|
||||||
|
|
||||||
|
await instance.initialize()
|
||||||
|
logger.success('服务 "%s" 启动成功', target.__name__)
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
|
except Exception as e: # pylint: disable=W0703
|
||||||
|
logger.exception('服务 "%s" 初始化失败,BOT 将自动关闭', target.__name__)
|
||||||
|
raise SystemExit from e
|
||||||
|
|
||||||
|
async def start_services(self) -> None:
|
||||||
|
for path in filter(
|
||||||
|
lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir()
|
||||||
|
):
|
||||||
|
_load_module(path)
|
||||||
|
|
||||||
|
for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类
|
||||||
|
instance = await self._initialize_service(service)
|
||||||
|
|
||||||
|
self._lib[service] = instance
|
||||||
|
self._services[service] = instance
|
||||||
|
|
||||||
|
async def stop_services(self) -> None:
|
||||||
|
"""关闭服务"""
|
||||||
|
if not self._services:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def task(s):
|
||||||
|
try:
|
||||||
|
async with timeout(5):
|
||||||
|
await s.shutdown()
|
||||||
|
logger.success('服务 "%s" 关闭成功', s.__class__.__name__)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning('服务 "%s" 关闭超时', s.__class__.__name__)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e)
|
||||||
|
|
||||||
|
logger.info("正在关闭服务")
|
||||||
|
tasks = []
|
||||||
|
for service in self._services.values():
|
||||||
|
tasks.append(asyncio.create_task(task(service)))
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginManager(Manager["PluginType"]):
|
||||||
|
"""插件管理"""
|
||||||
|
|
||||||
|
_plugins: Dict[Type["PluginType"], "PluginType"] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def plugins(self) -> List["PluginType"]:
|
||||||
|
"""所有已经加载的插件"""
|
||||||
|
return list(self._plugins.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]:
|
||||||
|
return self._plugins
|
||||||
|
|
||||||
|
async def install_plugins(self) -> None:
|
||||||
|
"""安装所有插件"""
|
||||||
|
from core.plugin import get_all_plugins
|
||||||
|
|
||||||
|
for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()):
|
||||||
|
_load_module(path)
|
||||||
|
|
||||||
|
for plugin in get_all_plugins():
|
||||||
|
plugin: Type["PluginType"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
instance: "PluginType" = await self.executor(plugin)
|
||||||
|
except Exception as e: # pylint: disable=W0703
|
||||||
|
logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._plugins[plugin] = instance
|
||||||
|
|
||||||
|
if self._application is not None:
|
||||||
|
instance.set_application(self._application)
|
||||||
|
|
||||||
|
await asyncio.create_task(self.plugin_install_task(plugin, instance))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"):
|
||||||
|
try:
|
||||||
|
await instance.install()
|
||||||
|
logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}")
|
||||||
|
except Exception as e: # pylint: disable=W0703
|
||||||
|
logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||||
|
|
||||||
|
async def uninstall_plugins(self) -> None:
|
||||||
|
for plugin in self._plugins.values():
|
||||||
|
try:
|
||||||
|
await plugin.uninstall()
|
||||||
|
except Exception as e: # pylint: disable=W0703
|
||||||
|
logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e)
|
||||||
|
|
||||||
|
|
||||||
|
class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager):
|
||||||
|
"""BOT 除自身外的生命周期管理类"""
|
106
core/override/telegram.py
Normal file
106
core/override/telegram.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化"""
|
||||||
|
from typing import Any, AsyncIterable, Optional
|
||||||
|
|
||||||
|
import httpcore
|
||||||
|
from httpx import (
|
||||||
|
AsyncByteStream,
|
||||||
|
AsyncHTTPTransport as DefaultAsyncHTTPTransport,
|
||||||
|
Limits,
|
||||||
|
Response as DefaultResponse,
|
||||||
|
Timeout,
|
||||||
|
)
|
||||||
|
from telegram.request import HTTPXRequest as DefaultHTTPXRequest
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ujson as jsonlib
|
||||||
|
except ImportError:
|
||||||
|
import json as jsonlib
|
||||||
|
|
||||||
|
__all__ = ("HTTPXRequest",)
|
||||||
|
|
||||||
|
|
||||||
|
class Response(DefaultResponse):
|
||||||
|
def json(self, **kwargs: Any) -> Any:
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
from httpx._utils import guess_json_utf
|
||||||
|
|
||||||
|
if self.charset_encoding is None and self.content and len(self.content) > 3:
|
||||||
|
encoding = guess_json_utf(self.content)
|
||||||
|
if encoding is not None:
|
||||||
|
return jsonlib.loads(self.content.decode(encoding), **kwargs)
|
||||||
|
return jsonlib.loads(self.text, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
class AsyncHTTPTransport(DefaultAsyncHTTPTransport):
|
||||||
|
async def handle_async_request(self, request) -> Response:
|
||||||
|
from httpx._transports.default import (
|
||||||
|
map_httpcore_exceptions,
|
||||||
|
AsyncResponseStream,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(request.stream, AsyncByteStream):
|
||||||
|
raise AssertionError
|
||||||
|
|
||||||
|
req = httpcore.Request(
|
||||||
|
method=request.method,
|
||||||
|
url=httpcore.URL(
|
||||||
|
scheme=request.url.raw_scheme,
|
||||||
|
host=request.url.raw_host,
|
||||||
|
port=request.url.port,
|
||||||
|
target=request.url.raw_path,
|
||||||
|
),
|
||||||
|
headers=request.headers.raw,
|
||||||
|
content=request.stream,
|
||||||
|
extensions=request.extensions,
|
||||||
|
)
|
||||||
|
with map_httpcore_exceptions():
|
||||||
|
resp = await self._pool.handle_async_request(req)
|
||||||
|
|
||||||
|
if not isinstance(resp.stream, AsyncIterable):
|
||||||
|
raise AssertionError
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
status_code=resp.status,
|
||||||
|
headers=resp.headers,
|
||||||
|
stream=AsyncResponseStream(resp.stream),
|
||||||
|
extensions=resp.extensions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPXRequest(DefaultHTTPXRequest):
|
||||||
|
def __init__( # pylint: disable=W0231
|
||||||
|
self,
|
||||||
|
connection_pool_size: int = 1,
|
||||||
|
proxy_url: str = None,
|
||||||
|
read_timeout: Optional[float] = 5.0,
|
||||||
|
write_timeout: Optional[float] = 5.0,
|
||||||
|
connect_timeout: Optional[float] = 5.0,
|
||||||
|
pool_timeout: Optional[float] = 1.0,
|
||||||
|
):
|
||||||
|
timeout = Timeout(
|
||||||
|
connect=connect_timeout,
|
||||||
|
read=read_timeout,
|
||||||
|
write=write_timeout,
|
||||||
|
pool=pool_timeout,
|
||||||
|
)
|
||||||
|
limits = Limits(
|
||||||
|
max_connections=connection_pool_size,
|
||||||
|
max_keepalive_connections=connection_pool_size,
|
||||||
|
)
|
||||||
|
self._client_kwargs = dict(
|
||||||
|
timeout=timeout,
|
||||||
|
proxies=proxy_url,
|
||||||
|
limits=limits,
|
||||||
|
transport=AsyncHTTPTransport(limits=limits),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._client = self._build_client()
|
||||||
|
except ImportError as exc:
|
||||||
|
if "httpx[socks]" not in str(exc):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
"To use Socks5 proxies, PTB must be installed via `pip install python-telegram-bot[socks]`."
|
||||||
|
) from exc
|
483
core/plugin.py
483
core/plugin.py
@ -1,483 +0,0 @@
|
|||||||
import copy
|
|
||||||
import datetime
|
|
||||||
import re
|
|
||||||
from importlib import import_module
|
|
||||||
from re import Pattern
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from telegram._utils.defaultvalue import DEFAULT_TRUE
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from telegram._utils.types import DVInput, JSONDict
|
|
||||||
from telegram.ext import BaseHandler, ConversationHandler, Job
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from telegram.ext._utils.types import JobCallback
|
|
||||||
from telegram.ext.filters import BaseFilter
|
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
|
|
||||||
__all__ = ["Plugin", "handler", "conversation", "job", "error_handler"]
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
|
||||||
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
|
||||||
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
|
|
||||||
|
|
||||||
_Module = import_module("telegram.ext")
|
|
||||||
|
|
||||||
_NORMAL_HANDLER_ATTR_NAME = "_handler_data"
|
|
||||||
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_data"
|
|
||||||
_JOB_ATTR_NAME = "_job_data"
|
|
||||||
|
|
||||||
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
|
|
||||||
|
|
||||||
|
|
||||||
class _Plugin:
|
|
||||||
def _make_handler(self, datas: Union[List[Dict], Dict]) -> List[HandlerType]:
|
|
||||||
result = []
|
|
||||||
if isinstance(datas, list):
|
|
||||||
for data in filter(lambda x: x, datas):
|
|
||||||
func = getattr(self, data.pop("func"))
|
|
||||||
result.append(data.pop("type")(callback=func, **data.pop("kwargs")))
|
|
||||||
else:
|
|
||||||
func = getattr(self, datas.pop("func"))
|
|
||||||
result.append(datas.pop("type")(callback=func, **datas.pop("kwargs")))
|
|
||||||
return result
|
|
||||||
|
|
||||||
@property
|
|
||||||
def handlers(self) -> List[HandlerType]:
|
|
||||||
result = []
|
|
||||||
for attr in dir(self):
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and isinstance(func := getattr(self, attr), MethodType)
|
|
||||||
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
|
||||||
):
|
|
||||||
for data in datas:
|
|
||||||
if data["type"] not in ["error", "new_chat_member"]:
|
|
||||||
result.extend(self._make_handler(data))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _new_chat_members_handler_funcs(self) -> List[Tuple[int, Callable]]:
|
|
||||||
result = []
|
|
||||||
for attr in dir(self):
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and isinstance(func := getattr(self, attr), MethodType)
|
|
||||||
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
|
||||||
):
|
|
||||||
for data in datas:
|
|
||||||
if data and data["type"] == "new_chat_member":
|
|
||||||
result.append((data["priority"], func))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@property
|
|
||||||
def error_handlers(self) -> Dict[Callable, bool]:
|
|
||||||
result = {}
|
|
||||||
for attr in dir(self):
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and isinstance(func := getattr(self, attr), MethodType)
|
|
||||||
and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
|
||||||
):
|
|
||||||
for data in datas:
|
|
||||||
if data and data["type"] == "error":
|
|
||||||
result.update({func: data["block"]})
|
|
||||||
return result
|
|
||||||
|
|
||||||
@property
|
|
||||||
def jobs(self) -> List[Job]:
|
|
||||||
from core.bot import bot
|
|
||||||
|
|
||||||
result = []
|
|
||||||
for attr in dir(self):
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
|
||||||
and isinstance(func := getattr(self, attr), MethodType)
|
|
||||||
and (datas := getattr(func, _JOB_ATTR_NAME, None))
|
|
||||||
):
|
|
||||||
for data in datas:
|
|
||||||
_job = getattr(bot.job_queue, data.pop("type"))(
|
|
||||||
callback=func, **data.pop("kwargs"), **{key: data.pop(key) for key in list(data.keys())}
|
|
||||||
)
|
|
||||||
result.append(_job)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class _Conversation(_Plugin):
|
|
||||||
_conversation_kwargs: Dict
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
|
||||||
cls._conversation_kwargs = kwargs
|
|
||||||
super(_Conversation, cls).__init_subclass__()
|
|
||||||
return cls
|
|
||||||
|
|
||||||
@property
|
|
||||||
def handlers(self) -> List[HandlerType]:
|
|
||||||
result: List[HandlerType] = []
|
|
||||||
|
|
||||||
entry_points: List[HandlerType] = []
|
|
||||||
states: Dict[Any, List[HandlerType]] = {}
|
|
||||||
fallbacks: List[HandlerType] = []
|
|
||||||
for attr in dir(self):
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
if (
|
|
||||||
not (attr.startswith("_") or attr == "handlers")
|
|
||||||
and isinstance(func := getattr(self, attr), Callable)
|
|
||||||
and (handler_datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None))
|
|
||||||
):
|
|
||||||
conversation_data = getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None)
|
|
||||||
if attr == "cancel":
|
|
||||||
handler_datas = copy.deepcopy(handler_datas)
|
|
||||||
conversation_data = copy.deepcopy(conversation_data)
|
|
||||||
_handlers = self._make_handler(handler_datas)
|
|
||||||
if conversation_data:
|
|
||||||
if (_type := conversation_data.pop("type")) == "entry":
|
|
||||||
entry_points.extend(_handlers)
|
|
||||||
elif _type == "state":
|
|
||||||
if (key := conversation_data.pop("state")) in states:
|
|
||||||
states[key].extend(_handlers)
|
|
||||||
else:
|
|
||||||
states[key] = _handlers
|
|
||||||
elif _type == "fallback":
|
|
||||||
fallbacks.extend(_handlers)
|
|
||||||
else:
|
|
||||||
result.extend(_handlers)
|
|
||||||
if entry_points or states or fallbacks:
|
|
||||||
result.append(
|
|
||||||
ConversationHandler(
|
|
||||||
entry_points, states, fallbacks, **self.__class__._conversation_kwargs # pylint: disable=W0212
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(_Plugin):
|
|
||||||
Conversation = _Conversation
|
|
||||||
|
|
||||||
|
|
||||||
class _Handler:
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _type(self) -> Type[BaseHandler]:
|
|
||||||
return getattr(_Module, f"{self.__class__.__name__.strip('_')}Handler")
|
|
||||||
|
|
||||||
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
|
|
||||||
data = {"type": self._type, "func": func.__name__, "kwargs": self.kwargs}
|
|
||||||
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
|
|
||||||
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
|
|
||||||
handler_datas.append(data)
|
|
||||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
|
|
||||||
else:
|
|
||||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class _CallbackQuery(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
):
|
|
||||||
super(_CallbackQuery, self).__init__(pattern=pattern, block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _ChatJoinRequest(_Handler):
|
|
||||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super(_ChatJoinRequest, self).__init__(block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _ChatMember(_Handler):
|
|
||||||
def __init__(self, chat_member_types: int = -1, block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super().__init__(chat_member_types=chat_member_types, block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _ChosenInlineResult(_Handler):
|
|
||||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE, pattern: Union[str, Pattern] = None):
|
|
||||||
super().__init__(block=block, pattern=pattern)
|
|
||||||
|
|
||||||
|
|
||||||
class _Command(_Handler):
|
|
||||||
def __init__(self, command: str, filters: "BaseFilter" = None, block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super(_Command, self).__init__(command=command, filters=filters, block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _InlineQuery(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self, pattern: Union[str, Pattern] = None, block: DVInput[bool] = DEFAULT_TRUE, chat_types: List[str] = None
|
|
||||||
):
|
|
||||||
super().__init__(pattern=pattern, block=block, chat_types=chat_types)
|
|
||||||
|
|
||||||
|
|
||||||
class _MessageNewChatMembers(_Handler):
|
|
||||||
def __init__(self, func: Callable[P, T] = None, *, priority: int = 5):
|
|
||||||
super().__init__()
|
|
||||||
self.func = func
|
|
||||||
self.priority = priority
|
|
||||||
|
|
||||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
|
||||||
self.func = self.func or func
|
|
||||||
data = {"type": "new_chat_member", "priority": self.priority}
|
|
||||||
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
|
|
||||||
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
|
|
||||||
handler_datas.append(data)
|
|
||||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
|
|
||||||
else:
|
|
||||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class _Message(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
filters: "BaseFilter",
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
):
|
|
||||||
super(_Message, self).__init__(filters=filters, block=block)
|
|
||||||
|
|
||||||
new_chat_members = _MessageNewChatMembers
|
|
||||||
|
|
||||||
|
|
||||||
class _PollAnswer(_Handler):
|
|
||||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super(_PollAnswer, self).__init__(block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _Poll(_Handler):
|
|
||||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super(_Poll, self).__init__(block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _PreCheckoutQuery(_Handler):
|
|
||||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super(_PreCheckoutQuery, self).__init__(block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _Prefix(_Handler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prefix: str,
|
|
||||||
command: str,
|
|
||||||
filters: BaseFilter = None,
|
|
||||||
block: DVInput[bool] = DEFAULT_TRUE,
|
|
||||||
):
|
|
||||||
super(_Prefix, self).__init__(prefix=prefix, command=command, filters=filters, block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _ShippingQuery(_Handler):
|
|
||||||
def __init__(self, block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super(_ShippingQuery, self).__init__(block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _StringCommand(_Handler):
|
|
||||||
def __init__(self, command: str):
|
|
||||||
super(_StringCommand, self).__init__(command=command)
|
|
||||||
|
|
||||||
|
|
||||||
class _StringRegex(_Handler):
|
|
||||||
def __init__(self, pattern: Union[str, Pattern], block: DVInput[bool] = DEFAULT_TRUE):
|
|
||||||
super(_StringRegex, self).__init__(pattern=pattern, block=block)
|
|
||||||
|
|
||||||
|
|
||||||
class _Type(_Handler):
|
|
||||||
# noinspection PyShadowingBuiltins
|
|
||||||
def __init__(
|
|
||||||
self, type: Type, strict: bool = False, block: DVInput[bool] = DEFAULT_TRUE # pylint: disable=redefined-builtin
|
|
||||||
):
|
|
||||||
super(_Type, self).__init__(type=type, strict=strict, block=block)
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class handler(_Handler):
|
|
||||||
def __init__(self, handler_type: Callable[P, HandlerType], **kwargs: P.kwargs):
|
|
||||||
self._type_ = handler_type
|
|
||||||
super(handler, self).__init__(**kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _type(self) -> Type[BaseHandler]:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
return self._type_
|
|
||||||
|
|
||||||
callback_query = _CallbackQuery
|
|
||||||
chat_join_request = _ChatJoinRequest
|
|
||||||
chat_member = _ChatMember
|
|
||||||
chosen_inline_result = _ChosenInlineResult
|
|
||||||
command = _Command
|
|
||||||
inline_query = _InlineQuery
|
|
||||||
message = _Message
|
|
||||||
poll_answer = _PollAnswer
|
|
||||||
pool = _Poll
|
|
||||||
pre_checkout_query = _PreCheckoutQuery
|
|
||||||
prefix = _Prefix
|
|
||||||
shipping_query = _ShippingQuery
|
|
||||||
string_command = _StringCommand
|
|
||||||
string_regex = _StringRegex
|
|
||||||
type = _Type
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class error_handler:
|
|
||||||
def __init__(self, func: Callable[P, T] = None, *, block: bool = DEFAULT_TRUE):
|
|
||||||
self._func = func
|
|
||||||
self._block = block
|
|
||||||
|
|
||||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
|
||||||
self._func = func or self._func
|
|
||||||
data = {"type": "error", "block": self._block}
|
|
||||||
if hasattr(func, _NORMAL_HANDLER_ATTR_NAME):
|
|
||||||
handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME)
|
|
||||||
handler_datas.append(data)
|
|
||||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas)
|
|
||||||
else:
|
|
||||||
setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data])
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def _entry(func: Callable[P, T]) -> Callable[P, T]:
|
|
||||||
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "entry"})
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class _State:
|
|
||||||
def __init__(self, state: Any):
|
|
||||||
self.state = state
|
|
||||||
|
|
||||||
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
|
||||||
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "state", "state": self.state})
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def _fallback(func: Callable[P, T]) -> Callable[P, T]:
|
|
||||||
setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "fallback"})
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class conversation(_Handler):
|
|
||||||
entry_point = _entry
|
|
||||||
state = _State
|
|
||||||
fallback = _fallback
|
|
||||||
|
|
||||||
|
|
||||||
class _Job:
|
|
||||||
kwargs: Dict = {}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str = None,
|
|
||||||
data: object = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.name = name
|
|
||||||
self.data = data
|
|
||||||
self.chat_id = chat_id
|
|
||||||
self.user_id = user_id
|
|
||||||
self.job_kwargs = {} if job_kwargs is None else job_kwargs
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
def __call__(self, func: JobCallback) -> JobCallback:
|
|
||||||
data = {
|
|
||||||
"name": self.name,
|
|
||||||
"data": self.data,
|
|
||||||
"chat_id": self.chat_id,
|
|
||||||
"user_id": self.user_id,
|
|
||||||
"job_kwargs": self.job_kwargs,
|
|
||||||
"kwargs": self.kwargs,
|
|
||||||
"type": re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
|
|
||||||
}
|
|
||||||
if hasattr(func, _JOB_ATTR_NAME):
|
|
||||||
job_datas = getattr(func, _JOB_ATTR_NAME)
|
|
||||||
job_datas.append(data)
|
|
||||||
setattr(func, _JOB_ATTR_NAME, job_datas)
|
|
||||||
else:
|
|
||||||
setattr(func, _JOB_ATTR_NAME, [data])
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class _RunOnce(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
when: TimeType,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, when=when)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunRepeating(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
interval: Union[float, datetime.timedelta],
|
|
||||||
first: TimeType = None,
|
|
||||||
last: TimeType = None,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, interval=interval, first=first, last=last)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunMonthly(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
when: datetime.time,
|
|
||||||
day: int,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, when=when, day=day)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunDaily(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
time: datetime.time,
|
|
||||||
days: Tuple[int, ...] = tuple(range(7)),
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs, time=time, days=days)
|
|
||||||
|
|
||||||
|
|
||||||
class _RunCustom(_Job):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data: object = None,
|
|
||||||
name: str = None,
|
|
||||||
chat_id: int = None,
|
|
||||||
user_id: int = None,
|
|
||||||
job_kwargs: JSONDict = None,
|
|
||||||
):
|
|
||||||
super().__init__(name, data, chat_id, user_id, job_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
|
||||||
class job:
|
|
||||||
run_once = _RunOnce
|
|
||||||
run_repeating = _RunRepeating
|
|
||||||
run_monthly = _RunMonthly
|
|
||||||
run_daily = _RunDaily
|
|
||||||
run_custom = _RunCustom
|
|
16
core/plugin/__init__.py
Normal file
16
core/plugin/__init__.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
"""插件"""
|
||||||
|
|
||||||
|
from core.plugin._handler import conversation, error_handler, handler
|
||||||
|
from core.plugin._job import TimeType, job
|
||||||
|
from core.plugin._plugin import Plugin, PluginType, get_all_plugins
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"Plugin",
|
||||||
|
"PluginType",
|
||||||
|
"get_all_plugins",
|
||||||
|
"handler",
|
||||||
|
"error_handler",
|
||||||
|
"conversation",
|
||||||
|
"job",
|
||||||
|
"TimeType",
|
||||||
|
)
|
175
core/plugin/_funcs.py
Normal file
175
core/plugin/_funcs.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import httpx
|
||||||
|
from httpx import UnsupportedProtocol
|
||||||
|
from telegram import Chat, Message, ReplyKeyboardRemove, Update
|
||||||
|
from telegram.error import BadRequest, Forbidden
|
||||||
|
from telegram.ext import CallbackContext, ConversationHandler, Job
|
||||||
|
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
|
from core.plugin._handler import conversation, handler
|
||||||
|
from utils.const import CACHE_DIR, REQUEST_HEADERS
|
||||||
|
from utils.error import UrlResourcesNotFoundError
|
||||||
|
from utils.helpers import sha1
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.application import Application
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ujson as json
|
||||||
|
except ImportError:
|
||||||
|
import json
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"PluginFuncs",
|
||||||
|
"ConversationFuncs",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFuncs:
|
||||||
|
_application: "Optional[Application]" = None
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def application(self) -> "Application":
|
||||||
|
if self._application is None:
|
||||||
|
raise RuntimeError("No application was set for this PluginManager.")
|
||||||
|
return self._application
|
||||||
|
|
||||||
|
async def _delete_message(self, context: CallbackContext) -> None:
|
||||||
|
job = context.job
|
||||||
|
message_id = job.data
|
||||||
|
chat_info = f"chat_id[{job.chat_id}]"
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat = await self.get_chat(job.chat_id)
|
||||||
|
full_name = chat.full_name
|
||||||
|
if full_name:
|
||||||
|
chat_info = f"{full_name}[{chat.id}]"
|
||||||
|
else:
|
||||||
|
chat_info = f"{chat.title}[{chat.id}]"
|
||||||
|
except (BadRequest, Forbidden) as exc:
|
||||||
|
logger.warning("获取 chat info 失败 %s", exc.message)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("获取 chat info 消息失败 %s", str(exc))
|
||||||
|
|
||||||
|
logger.debug("删除消息 %s message_id[%s]", chat_info, message_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id)
|
||||||
|
except BadRequest as exc:
|
||||||
|
logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message)
|
||||||
|
|
||||||
|
async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, ttl: int = 86400) -> Chat:
|
||||||
|
application = self.application
|
||||||
|
redis_db: RedisDB = redis_db or self.application.managers.services_map.get(RedisDB, None)
|
||||||
|
|
||||||
|
if not redis_db:
|
||||||
|
return await application.bot.get_chat(chat_id)
|
||||||
|
|
||||||
|
qname = f"bot:chat:{chat_id}"
|
||||||
|
|
||||||
|
data = await redis_db.client.get(qname)
|
||||||
|
if data:
|
||||||
|
json_data = json.loads(data)
|
||||||
|
return Chat.de_json(json_data, application.telegram.bot)
|
||||||
|
|
||||||
|
chat_info = await application.telegram.bot.get_chat(chat_id)
|
||||||
|
await redis_db.client.set(qname, chat_info.to_json())
|
||||||
|
await redis_db.client.expire(qname, ttl)
|
||||||
|
return chat_info
|
||||||
|
|
||||||
|
def add_delete_message_job(
|
||||||
|
self,
|
||||||
|
message: Optional[Union[int, Message]] = None,
|
||||||
|
*,
|
||||||
|
delay: int = 60,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
chat: Optional[Union[int, Chat]] = None,
|
||||||
|
context: Optional[CallbackContext] = None,
|
||||||
|
) -> Job:
|
||||||
|
"""延迟删除消息"""
|
||||||
|
|
||||||
|
if isinstance(message, Message):
|
||||||
|
if chat is None:
|
||||||
|
chat = message.chat_id
|
||||||
|
message = message.id
|
||||||
|
|
||||||
|
chat = chat.id if isinstance(chat, Chat) else chat
|
||||||
|
|
||||||
|
job_queue = self.application.job_queue or context.job_queue
|
||||||
|
|
||||||
|
if job_queue is None:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
return job_queue.run_once(
|
||||||
|
callback=self._delete_message,
|
||||||
|
when=delay,
|
||||||
|
data=message,
|
||||||
|
name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message",
|
||||||
|
chat_id=chat,
|
||||||
|
job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def download_resource(url: str, return_path: bool = False) -> str:
|
||||||
|
url_sha1 = sha1(url) # url 的 hash 值
|
||||||
|
pathed_url = Path(url)
|
||||||
|
|
||||||
|
file_name = url_sha1 + pathed_url.suffix
|
||||||
|
file_path = CACHE_DIR.joinpath(file_name)
|
||||||
|
|
||||||
|
if not file_path.exists(): # 若文件不存在,则下载
|
||||||
|
async with httpx.AsyncClient(headers=REQUEST_HEADERS) as client:
|
||||||
|
try:
|
||||||
|
response = await client.get(url)
|
||||||
|
except UnsupportedProtocol:
|
||||||
|
logger.error("链接不支持 url[%s]", url)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if response.is_error:
|
||||||
|
logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code)
|
||||||
|
raise UrlResourcesNotFoundError(url)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code)
|
||||||
|
raise UrlResourcesNotFoundError(url)
|
||||||
|
|
||||||
|
async with aiofiles.open(file_path, mode="wb") as f:
|
||||||
|
await f.write(response.content)
|
||||||
|
|
||||||
|
logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path)
|
||||||
|
|
||||||
|
return file_path if return_path else Path(file_path).as_uri()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_args(context: Optional[CallbackContext] = None) -> List[str]:
|
||||||
|
args = context.args
|
||||||
|
match = context.match
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
if match is not None and (command := match.groups()[0]):
|
||||||
|
temp = []
|
||||||
|
command_parts = command.split(" ")
|
||||||
|
for command_part in command_parts:
|
||||||
|
if command_part:
|
||||||
|
temp.append(command_part)
|
||||||
|
return temp
|
||||||
|
return []
|
||||||
|
if len(args) >= 1:
|
||||||
|
return args
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationFuncs:
|
||||||
|
@conversation.fallback
|
||||||
|
@handler.command(command="cancel", block=True)
|
||||||
|
async def cancel(self, update: Update, _) -> int:
|
||||||
|
await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove())
|
||||||
|
return ConversationHandler.END
|
380
core/plugin/_handler.py
Normal file
380
core/plugin/_handler.py
Normal file
@ -0,0 +1,380 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from functools import wraps
|
||||||
|
from importlib import import_module
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Pattern,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
from telegram._utils.defaultvalue import DEFAULT_TRUE
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
from telegram._utils.types import DVInput
|
||||||
|
from telegram.ext import BaseHandler
|
||||||
|
from telegram.ext.filters import BaseFilter
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
from core.handler.callbackqueryhandler import CallbackQueryHandler
|
||||||
|
from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.builtins.dispatcher import AbstractDispatcher
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"handler",
|
||||||
|
"conversation",
|
||||||
|
"ConversationDataType",
|
||||||
|
"ConversationData",
|
||||||
|
"HandlerData",
|
||||||
|
"ErrorHandlerData",
|
||||||
|
"error_handler",
|
||||||
|
)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
R = TypeVar("R")
|
||||||
|
UT = TypeVar("UT")
|
||||||
|
|
||||||
|
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
||||||
|
HandlerCls = Type[HandlerType]
|
||||||
|
|
||||||
|
Module = import_module("telegram.ext")
|
||||||
|
|
||||||
|
HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||||
|
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||||
|
|
||||||
|
ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
||||||
|
|
||||||
|
CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
||||||
|
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
||||||
|
|
||||||
|
WRAPPER_ASSIGNMENTS = list(
|
||||||
|
set(
|
||||||
|
_WRAPPER_ASSIGNMENTS
|
||||||
|
+ [
|
||||||
|
HANDLER_DATA_ATTR_NAME,
|
||||||
|
ERROR_HANDLER_ATTR_NAME,
|
||||||
|
CONVERSATION_HANDLER_ATTR_NAME,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(init=True)
|
||||||
|
class HandlerData:
|
||||||
|
type: Type[HandlerType]
|
||||||
|
admin: bool
|
||||||
|
kwargs: Dict[str, Any]
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class _Handler:
|
||||||
|
_type: Type["HandlerType"]
|
||||||
|
|
||||||
|
kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs) -> None:
|
||||||
|
"""用于获取 python-telegram-bot 中对应的 handler class"""
|
||||||
|
|
||||||
|
handler_name = f"{cls.__name__.strip('_')}Handler"
|
||||||
|
|
||||||
|
if handler_name == "CallbackQueryHandler":
|
||||||
|
cls._type = CallbackQueryHandler
|
||||||
|
return
|
||||||
|
|
||||||
|
cls._type = getattr(Module, handler_name, None)
|
||||||
|
|
||||||
|
def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None:
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
self.admin = admin
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""decorator实现,从 func 生成 Handler"""
|
||||||
|
|
||||||
|
handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, [])
|
||||||
|
handler_datas.append(
|
||||||
|
HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher)
|
||||||
|
)
|
||||||
|
setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class _CallbackQuery(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None,
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
admin: bool = False,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _ChatJoinRequest(_Handler):
|
||||||
|
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||||
|
super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _ChatMember(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chat_member_types: int = -1,
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _ChosenInlineResult(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
*,
|
||||||
|
pattern: Union[str, Pattern] = None,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(block=block, pattern=pattern, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _Command(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
command: Union[str, List[str]],
|
||||||
|
filters: "BaseFilter" = None,
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
admin: bool = False,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super(_Command, self).__init__(
|
||||||
|
command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _InlineQuery(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pattern: Union[str, Pattern] = None,
|
||||||
|
chat_types: List[str] = None,
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _Message(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filters: BaseFilter,
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
admin: bool = False,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
) -> None:
|
||||||
|
super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _PollAnswer(_Handler):
|
||||||
|
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||||
|
super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _Poll(_Handler):
|
||||||
|
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||||
|
super(_Poll, self).__init__(block=block, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _PreCheckoutQuery(_Handler):
|
||||||
|
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||||
|
super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _Prefix(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
command: str,
|
||||||
|
filters: BaseFilter = None,
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super(_Prefix, self).__init__(
|
||||||
|
prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _ShippingQuery(_Handler):
|
||||||
|
def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None):
|
||||||
|
super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _StringCommand(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
*,
|
||||||
|
admin: bool = False,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _StringRegex(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pattern: Union[str, Pattern],
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
admin: bool = False,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
class _Type(_Handler):
|
||||||
|
# noinspection PyShadowingBuiltins
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
type: Type[UT], # pylint: disable=W0622
|
||||||
|
strict: bool = False,
|
||||||
|
*,
|
||||||
|
block: DVInput[bool] = DEFAULT_TRUE,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
): # pylint: disable=redefined-builtin
|
||||||
|
super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyPep8Naming
|
||||||
|
class handler(_Handler):
|
||||||
|
callback_query = _CallbackQuery
|
||||||
|
chat_join_request = _ChatJoinRequest
|
||||||
|
chat_member = _ChatMember
|
||||||
|
chosen_inline_result = _ChosenInlineResult
|
||||||
|
command = _Command
|
||||||
|
inline_query = _InlineQuery
|
||||||
|
message = _Message
|
||||||
|
poll_answer = _PollAnswer
|
||||||
|
pool = _Poll
|
||||||
|
pre_checkout_query = _PreCheckoutQuery
|
||||||
|
prefix = _Prefix
|
||||||
|
shipping_query = _ShippingQuery
|
||||||
|
string_command = _StringCommand
|
||||||
|
string_regex = _StringRegex
|
||||||
|
type = _Type
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]],
|
||||||
|
*,
|
||||||
|
admin: bool = False,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> None:
|
||||||
|
self._type = handler_type
|
||||||
|
super().__init__(admin=admin, dispatcher=dispatcher, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationDataType(Enum):
|
||||||
|
"""conversation handler 的类型"""
|
||||||
|
|
||||||
|
Entry = "entry"
|
||||||
|
State = "state"
|
||||||
|
Fallback = "fallback"
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationData(BaseModel):
|
||||||
|
"""用于储存 conversation handler 的数据"""
|
||||||
|
|
||||||
|
type: ConversationDataType
|
||||||
|
state: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class _ConversationType:
|
||||||
|
_type: ClassVar[ConversationDataType]
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs) -> None:
|
||||||
|
cls._type = ConversationDataType(cls.__name__.lstrip("_").lower())
|
||||||
|
|
||||||
|
|
||||||
|
def _entry(func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry))
|
||||||
|
|
||||||
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||||
|
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class _State(_ConversationType):
|
||||||
|
def __init__(self, state: Any) -> None:
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]:
|
||||||
|
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state))
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback(func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback))
|
||||||
|
|
||||||
|
@wraps(func, assigned=WRAPPER_ASSIGNMENTS)
|
||||||
|
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyPep8Naming
|
||||||
|
class conversation(_Handler):
|
||||||
|
entry_point = _entry
|
||||||
|
state = _State
|
||||||
|
fallback = _fallback
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(init=True)
|
||||||
|
class ErrorHandlerData:
|
||||||
|
block: bool
|
||||||
|
func: Optional[Callable] = None
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyPep8Naming
|
||||||
|
class error_handler:
|
||||||
|
_func: Callable[P, R]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
block: bool = DEFAULT_TRUE,
|
||||||
|
):
|
||||||
|
self._block = block
|
||||||
|
|
||||||
|
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
|
||||||
|
self._func = func
|
||||||
|
wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self)
|
||||||
|
|
||||||
|
handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, [])
|
||||||
|
handler_datas.append(ErrorHandlerData(block=self._block))
|
||||||
|
setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas)
|
||||||
|
|
||||||
|
return self._func
|
173
core/plugin/_job.py
Normal file
173
core/plugin/_job.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
"""插件"""
|
||||||
|
import datetime
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
from telegram._utils.types import JSONDict
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
from telegram.ext._utils.types import JobCallback
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.builtins.dispatcher import AbstractDispatcher
|
||||||
|
|
||||||
|
__all__ = ["TimeType", "job", "JobData"]
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time]
|
||||||
|
|
||||||
|
_JOB_ATTR_NAME = "_job_data"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(init=True)
|
||||||
|
class JobData:
|
||||||
|
name: str
|
||||||
|
data: Any
|
||||||
|
chat_id: int
|
||||||
|
user_id: int
|
||||||
|
type: str
|
||||||
|
job_kwargs: JSONDict = field(default_factory=dict)
|
||||||
|
kwargs: JSONDict = field(default_factory=dict)
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class _Job:
|
||||||
|
kwargs: Dict = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = None,
|
||||||
|
data: object = None,
|
||||||
|
chat_id: int = None,
|
||||||
|
user_id: int = None,
|
||||||
|
job_kwargs: JSONDict = None,
|
||||||
|
*,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.data = data
|
||||||
|
self.chat_id = chat_id
|
||||||
|
self.user_id = user_id
|
||||||
|
self.job_kwargs = {} if job_kwargs is None else job_kwargs
|
||||||
|
self.kwargs = kwargs
|
||||||
|
if dispatcher is None:
|
||||||
|
from core.builtins.dispatcher import JobDispatcher
|
||||||
|
|
||||||
|
dispatcher = JobDispatcher
|
||||||
|
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
|
||||||
|
def __call__(self, func: JobCallback) -> JobCallback:
|
||||||
|
data = JobData(
|
||||||
|
name=self.name,
|
||||||
|
data=self.data,
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
job_kwargs=self.job_kwargs,
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"),
|
||||||
|
dispatcher=self.dispatcher,
|
||||||
|
)
|
||||||
|
if hasattr(func, _JOB_ATTR_NAME):
|
||||||
|
job_datas = getattr(func, _JOB_ATTR_NAME)
|
||||||
|
job_datas.append(data)
|
||||||
|
setattr(func, _JOB_ATTR_NAME, job_datas)
|
||||||
|
else:
|
||||||
|
setattr(func, _JOB_ATTR_NAME, [data])
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class _RunOnce(_Job):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
when: TimeType,
|
||||||
|
data: object = None,
|
||||||
|
name: str = None,
|
||||||
|
chat_id: int = None,
|
||||||
|
user_id: int = None,
|
||||||
|
job_kwargs: JSONDict = None,
|
||||||
|
*,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when)
|
||||||
|
|
||||||
|
|
||||||
|
class _RunRepeating(_Job):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
interval: Union[float, datetime.timedelta],
|
||||||
|
first: TimeType = None,
|
||||||
|
last: TimeType = None,
|
||||||
|
data: object = None,
|
||||||
|
name: str = None,
|
||||||
|
chat_id: int = None,
|
||||||
|
user_id: int = None,
|
||||||
|
job_kwargs: JSONDict = None,
|
||||||
|
*,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _RunMonthly(_Job):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
when: datetime.time,
|
||||||
|
day: int,
|
||||||
|
data: object = None,
|
||||||
|
name: str = None,
|
||||||
|
chat_id: int = None,
|
||||||
|
user_id: int = None,
|
||||||
|
job_kwargs: JSONDict = None,
|
||||||
|
*,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day)
|
||||||
|
|
||||||
|
|
||||||
|
class _RunDaily(_Job):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
time: datetime.time,
|
||||||
|
days: Tuple[int, ...] = tuple(range(7)),
|
||||||
|
data: object = None,
|
||||||
|
name: str = None,
|
||||||
|
chat_id: int = None,
|
||||||
|
user_id: int = None,
|
||||||
|
job_kwargs: JSONDict = None,
|
||||||
|
*,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days)
|
||||||
|
|
||||||
|
|
||||||
|
class _RunCustom(_Job):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data: object = None,
|
||||||
|
name: str = None,
|
||||||
|
chat_id: int = None,
|
||||||
|
user_id: int = None,
|
||||||
|
job_kwargs: JSONDict = None,
|
||||||
|
*,
|
||||||
|
dispatcher: Optional[Type["AbstractDispatcher"]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyPep8Naming
|
||||||
|
class job:
|
||||||
|
run_once = _RunOnce
|
||||||
|
run_repeating = _RunRepeating
|
||||||
|
run_monthly = _RunMonthly
|
||||||
|
run_daily = _RunDaily
|
||||||
|
run_custom = _RunCustom
|
303
core/plugin/_plugin.py
Normal file
303
core/plugin/_plugin.py
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
"""插件"""
|
||||||
|
import asyncio
|
||||||
|
from abc import ABC
|
||||||
|
from dataclasses import asdict
|
||||||
|
from datetime import timedelta
|
||||||
|
from functools import partial, wraps
|
||||||
|
from itertools import chain
|
||||||
|
from multiprocessing import RLock as Lock
|
||||||
|
from types import MethodType
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
ClassVar,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
from core.handler.adminhandler import AdminHandler
|
||||||
|
from core.plugin._funcs import ConversationFuncs, PluginFuncs
|
||||||
|
from core.plugin._handler import ConversationDataType
|
||||||
|
from utils.const import WRAPPER_ASSIGNMENTS
|
||||||
|
from utils.helpers import isabstract
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.application import Application
|
||||||
|
from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData
|
||||||
|
from core.plugin._job import JobData
|
||||||
|
from multiprocessing.synchronize import RLock as LockType
|
||||||
|
|
||||||
|
__all__ = ("Plugin", "PluginType", "get_all_plugins")
|
||||||
|
|
||||||
|
wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS)
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
HandlerType = TypeVar("HandlerType", bound=BaseHandler)
|
||||||
|
|
||||||
|
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||||
|
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||||
|
|
||||||
|
_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data"
|
||||||
|
"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名"""
|
||||||
|
|
||||||
|
_ERROR_HANDLER_ATTR_NAME = "_error_handler_data"
|
||||||
|
|
||||||
|
_JOB_ATTR_NAME = "_job_data"
|
||||||
|
|
||||||
|
_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"]
|
||||||
|
|
||||||
|
|
||||||
|
class _Plugin(PluginFuncs):
|
||||||
|
"""插件"""
|
||||||
|
|
||||||
|
_lock: ClassVar["LockType"] = Lock()
|
||||||
|
_asyncio_lock: ClassVar["LockType"] = asyncio.Lock()
|
||||||
|
_installed: bool = False
|
||||||
|
|
||||||
|
_handlers: Optional[List[HandlerType]] = None
|
||||||
|
_error_handlers: Optional[List["ErrorHandlerData"]] = None
|
||||||
|
_jobs: Optional[List[Job]] = None
|
||||||
|
_application: "Optional[Application]" = None
|
||||||
|
|
||||||
|
def set_application(self, application: "Application") -> None:
|
||||||
|
self._application = application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def application(self) -> "Application":
|
||||||
|
if self._application is None:
|
||||||
|
raise RuntimeError("No application was set for this Plugin.")
|
||||||
|
return self._application
|
||||||
|
|
||||||
|
@property
|
||||||
|
def handlers(self) -> List[HandlerType]:
|
||||||
|
"""该插件的所有 handler"""
|
||||||
|
with self._lock:
|
||||||
|
if self._handlers is None:
|
||||||
|
self._handlers = []
|
||||||
|
|
||||||
|
for attr in dir(self):
|
||||||
|
if (
|
||||||
|
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||||
|
and isinstance(func := getattr(self, attr), MethodType)
|
||||||
|
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
||||||
|
):
|
||||||
|
for data in datas:
|
||||||
|
data: "HandlerData"
|
||||||
|
if data.admin:
|
||||||
|
self._handlers.append(
|
||||||
|
AdminHandler(
|
||||||
|
handler=data.type(
|
||||||
|
callback=func,
|
||||||
|
**data.kwargs,
|
||||||
|
),
|
||||||
|
application=self.application,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._handlers.append(
|
||||||
|
data.type(
|
||||||
|
callback=func,
|
||||||
|
**data.kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return self._handlers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error_handlers(self) -> List["ErrorHandlerData"]:
|
||||||
|
with self._lock:
|
||||||
|
if self._error_handlers is None:
|
||||||
|
self._error_handlers = []
|
||||||
|
for attr in dir(self):
|
||||||
|
if (
|
||||||
|
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||||
|
and isinstance(func := getattr(self, attr), MethodType)
|
||||||
|
and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, []))
|
||||||
|
):
|
||||||
|
for data in datas:
|
||||||
|
data: "ErrorHandlerData"
|
||||||
|
data.func = func
|
||||||
|
self._error_handlers.append(data)
|
||||||
|
|
||||||
|
return self._error_handlers
|
||||||
|
|
||||||
|
def _install_jobs(self) -> None:
|
||||||
|
if self._jobs is None:
|
||||||
|
self._jobs = []
|
||||||
|
for attr in dir(self):
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if (
|
||||||
|
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||||
|
and isinstance(func := getattr(self, attr), MethodType)
|
||||||
|
and (datas := getattr(func, _JOB_ATTR_NAME, []))
|
||||||
|
):
|
||||||
|
for data in datas:
|
||||||
|
data: "JobData"
|
||||||
|
self._jobs.append(
|
||||||
|
getattr(self.application.telegram.job_queue, data.type)(
|
||||||
|
callback=func,
|
||||||
|
**data.kwargs,
|
||||||
|
**{
|
||||||
|
key: value
|
||||||
|
for key, value in asdict(data).items()
|
||||||
|
if key not in ["type", "kwargs", "dispatcher"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def jobs(self) -> List[Job]:
|
||||||
|
with self._lock:
|
||||||
|
if self._jobs is None:
|
||||||
|
self._jobs = []
|
||||||
|
self._install_jobs()
|
||||||
|
return self._jobs
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""初始化插件"""
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""销毁插件"""
|
||||||
|
|
||||||
|
async def install(self) -> None:
|
||||||
|
"""安装"""
|
||||||
|
group = id(self)
|
||||||
|
if not self._installed:
|
||||||
|
await self.initialize()
|
||||||
|
# initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题
|
||||||
|
async with self._asyncio_lock:
|
||||||
|
self._install_jobs()
|
||||||
|
|
||||||
|
for h in self.handlers:
|
||||||
|
if not isinstance(h, TypeHandler):
|
||||||
|
self.application.telegram.add_handler(h, group)
|
||||||
|
else:
|
||||||
|
self.application.telegram.add_handler(h, -1)
|
||||||
|
|
||||||
|
for h in self.error_handlers:
|
||||||
|
self.application.telegram.add_error_handler(h.func, h.block)
|
||||||
|
self._installed = True
|
||||||
|
|
||||||
|
async def uninstall(self) -> None:
|
||||||
|
"""卸载"""
|
||||||
|
group = id(self)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self._installed:
|
||||||
|
if group in self.application.telegram.handlers:
|
||||||
|
del self.application.telegram.handlers[id(self)]
|
||||||
|
|
||||||
|
for h in self.handlers:
|
||||||
|
if isinstance(h, TypeHandler):
|
||||||
|
self.application.telegram.remove_handler(h, -1)
|
||||||
|
for h in self.error_handlers:
|
||||||
|
self.application.telegram.remove_error_handler(h.func)
|
||||||
|
|
||||||
|
for j in self.application.telegram.job_queue.jobs():
|
||||||
|
j.schedule_removal()
|
||||||
|
await self.shutdown()
|
||||||
|
self._installed = False
|
||||||
|
|
||||||
|
async def reload(self) -> None:
|
||||||
|
await self.uninstall()
|
||||||
|
await self.install()
|
||||||
|
|
||||||
|
|
||||||
|
class _Conversation(_Plugin, ConversationFuncs, ABC):
|
||||||
|
"""Conversation类"""
|
||||||
|
|
||||||
|
# noinspection SpellCheckingInspection
|
||||||
|
class Config(BaseModel):
|
||||||
|
allow_reentry: bool = False
|
||||||
|
per_chat: bool = True
|
||||||
|
per_user: bool = True
|
||||||
|
per_message: bool = False
|
||||||
|
conversation_timeout: Optional[Union[float, timedelta]] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
map_to_parent: Optional[Dict[object, object]] = None
|
||||||
|
block: bool = False
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
cls._conversation_kwargs = kwargs
|
||||||
|
super(_Conversation, cls).__init_subclass__()
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@property
|
||||||
|
def handlers(self) -> List[HandlerType]:
|
||||||
|
with self._lock:
|
||||||
|
if self._handlers is None:
|
||||||
|
self._handlers = []
|
||||||
|
|
||||||
|
entry_points: List[HandlerType] = []
|
||||||
|
states: Dict[Any, List[HandlerType]] = {}
|
||||||
|
fallbacks: List[HandlerType] = []
|
||||||
|
for attr in dir(self):
|
||||||
|
if (
|
||||||
|
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||||
|
and (func := getattr(self, attr, None)) is not None
|
||||||
|
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
||||||
|
):
|
||||||
|
conversation_data: "ConversationData"
|
||||||
|
|
||||||
|
handlers: List[HandlerType] = []
|
||||||
|
for data in datas:
|
||||||
|
handlers.append(
|
||||||
|
data.type(
|
||||||
|
callback=func,
|
||||||
|
**data.kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None):
|
||||||
|
if (_type := conversation_data.type) is ConversationDataType.Entry:
|
||||||
|
entry_points.extend(handlers)
|
||||||
|
elif _type is ConversationDataType.State:
|
||||||
|
if conversation_data.state in states:
|
||||||
|
states[conversation_data.state].extend(handlers)
|
||||||
|
else:
|
||||||
|
states[conversation_data.state] = handlers
|
||||||
|
elif _type is ConversationDataType.Fallback:
|
||||||
|
fallbacks.extend(handlers)
|
||||||
|
else:
|
||||||
|
self._handlers.extend(handlers)
|
||||||
|
else:
|
||||||
|
self._handlers.extend(handlers)
|
||||||
|
if entry_points and states and fallbacks:
|
||||||
|
kwargs = self._conversation_kwargs
|
||||||
|
kwargs.update(self.Config().dict())
|
||||||
|
self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs))
|
||||||
|
else:
|
||||||
|
temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks}
|
||||||
|
reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items()))
|
||||||
|
logger.warning(
|
||||||
|
"'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason)
|
||||||
|
)
|
||||||
|
return self._handlers
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(_Plugin, ABC):
|
||||||
|
"""插件"""
|
||||||
|
|
||||||
|
Conversation = _Conversation
|
||||||
|
|
||||||
|
|
||||||
|
PluginType = TypeVar("PluginType", bound=_Plugin)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_plugins() -> Iterable[Type[PluginType]]:
|
||||||
|
"""获取所有 Plugin 的子类"""
|
||||||
|
return filter(
|
||||||
|
lambda x: x.__name__[0] != "_" and not isabstract(x),
|
||||||
|
chain(Plugin.__subclasses__(), _Conversation.__subclasses__()),
|
||||||
|
)
|
@ -1,14 +0,0 @@
|
|||||||
from core.base.mysql import MySQL
|
|
||||||
from core.base.redisdb import RedisDB
|
|
||||||
from core.service import init_service
|
|
||||||
from .cache import QuizCache
|
|
||||||
from .repositories import QuizRepository
|
|
||||||
from .services import QuizService
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_quiz_service(mysql: MySQL, redis: RedisDB):
|
|
||||||
_repository = QuizRepository(mysql)
|
|
||||||
_cache = QuizCache(redis)
|
|
||||||
_service = QuizService(_repository, _cache)
|
|
||||||
return _service
|
|
@ -1,19 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from .models import Answer, Question
|
|
||||||
|
|
||||||
|
|
||||||
def CreatQuestionFromSQLData(data: tuple) -> List[Question]:
|
|
||||||
temp_list = []
|
|
||||||
for temp_data in data:
|
|
||||||
(question_id, text) = temp_data
|
|
||||||
temp_list.append(Question(question_id, text))
|
|
||||||
return temp_list
|
|
||||||
|
|
||||||
|
|
||||||
def CreatAnswerFromSQLData(data: tuple) -> List[Answer]:
|
|
||||||
temp_list = []
|
|
||||||
for temp_data in data:
|
|
||||||
(answer_id, question_id, is_correct, text) = temp_data
|
|
||||||
temp_list.append(Answer(answer_id, question_id, is_correct, text))
|
|
||||||
return temp_list
|
|
@ -1,98 +0,0 @@
|
|||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from sqlmodel import SQLModel, Field, Column, Integer, ForeignKey
|
|
||||||
|
|
||||||
from utils.baseobject import BaseObject
|
|
||||||
from utils.typedefs import JSONDict
|
|
||||||
|
|
||||||
|
|
||||||
class AnswerDB(SQLModel, table=True):
|
|
||||||
__tablename__ = "answer"
|
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
|
||||||
|
|
||||||
id: int = Field(primary_key=True)
|
|
||||||
question_id: Optional[int] = Field(
|
|
||||||
sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT"))
|
|
||||||
)
|
|
||||||
is_correct: Optional[bool] = Field()
|
|
||||||
text: Optional[str] = Field()
|
|
||||||
|
|
||||||
|
|
||||||
class QuestionDB(SQLModel, table=True):
|
|
||||||
__tablename__ = "question"
|
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
|
||||||
|
|
||||||
id: int = Field(primary_key=True)
|
|
||||||
text: Optional[str] = Field()
|
|
||||||
|
|
||||||
|
|
||||||
class Answer(BaseObject):
|
|
||||||
def __init__(self, answer_id: int = 0, question_id: int = 0, is_correct: bool = True, text: str = ""):
|
|
||||||
"""Answer类
|
|
||||||
|
|
||||||
:param answer_id: 答案ID
|
|
||||||
:param question_id: 与之对应的问题ID
|
|
||||||
:param is_correct: 该答案是否正确
|
|
||||||
:param text: 答案文本
|
|
||||||
"""
|
|
||||||
self.answer_id = answer_id
|
|
||||||
self.question_id = question_id
|
|
||||||
self.text = text
|
|
||||||
self.is_correct = is_correct
|
|
||||||
|
|
||||||
__slots__ = ("answer_id", "question_id", "text", "is_correct")
|
|
||||||
|
|
||||||
def to_database_data(self) -> AnswerDB:
|
|
||||||
data = AnswerDB()
|
|
||||||
data.id = self.answer_id
|
|
||||||
data.question_id = self.question_id
|
|
||||||
data.text = self.text
|
|
||||||
data.is_correct = self.is_correct
|
|
||||||
return data
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def de_database_data(cls, data: Optional[AnswerDB]) -> Optional["Answer"]:
|
|
||||||
if data is None:
|
|
||||||
return cls()
|
|
||||||
return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct)
|
|
||||||
|
|
||||||
|
|
||||||
class Question(BaseObject):
|
|
||||||
def __init__(self, question_id: int = 0, text: str = "", answers: List[Answer] = None):
|
|
||||||
"""Question类
|
|
||||||
|
|
||||||
:param question_id: 问题ID
|
|
||||||
:param text: 问题文本
|
|
||||||
:param answers: 答案列表
|
|
||||||
"""
|
|
||||||
self.question_id = question_id
|
|
||||||
self.text = text
|
|
||||||
self.answers = [] if answers is None else answers
|
|
||||||
|
|
||||||
def to_database_data(self) -> QuestionDB:
|
|
||||||
data = QuestionDB()
|
|
||||||
data.text = self.text
|
|
||||||
data.id = self.question_id
|
|
||||||
return data
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def de_database_data(cls, data: Optional[QuestionDB]) -> Optional["Question"]:
|
|
||||||
if data is None:
|
|
||||||
return cls()
|
|
||||||
return cls(question_id=data.id, text=data.text)
|
|
||||||
|
|
||||||
def to_dict(self) -> JSONDict:
|
|
||||||
data = super().to_dict()
|
|
||||||
if self.answers:
|
|
||||||
data["answers"] = [e.to_dict() for e in self.answers]
|
|
||||||
return data
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def de_json(cls, data: Optional[JSONDict]) -> Optional["Question"]:
|
|
||||||
data = cls._parse_data(data)
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
data["answers"] = Answer.de_list(data.get("answers"))
|
|
||||||
return cls(**data)
|
|
||||||
|
|
||||||
__slots__ = ("question_id", "text", "answers")
|
|
@ -1,10 +0,0 @@
|
|||||||
from core.service import init_service
|
|
||||||
from .services import SearchServices as _SearchServices
|
|
||||||
|
|
||||||
__all__ = []
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_search_service():
|
|
||||||
_service = _SearchServices()
|
|
||||||
return _service
|
|
@ -1,31 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from utils.log import logger
|
|
||||||
|
|
||||||
__all__ = ["Service", "init_service"]
|
|
||||||
|
|
||||||
|
|
||||||
class Service(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
"""初始化"""
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动 service"""
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""关闭 service"""
|
|
||||||
|
|
||||||
|
|
||||||
def init_service(func: Callable):
|
|
||||||
from core.bot import bot
|
|
||||||
|
|
||||||
if bot.is_running:
|
|
||||||
try:
|
|
||||||
service = bot.init_inject(func)
|
|
||||||
logger.success(f'服务 "{service.__class__.__name__}" 初始化成功')
|
|
||||||
bot.add_service(service)
|
|
||||||
except Exception as e: # pylint: disable=W0703
|
|
||||||
logger.exception(f"来自{func.__module__}的服务初始化失败:{e}")
|
|
||||||
return func
|
|
0
core/services/__init__.py
Normal file
0
core/services/__init__.py
Normal file
5
core/services/cookies/__init__.py
Normal file
5
core/services/cookies/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""CookieService"""
|
||||||
|
|
||||||
|
from core.services.cookies.services import CookiesService, PublicCookiesService
|
||||||
|
|
||||||
|
__all__ = ("CookiesService", "PublicCookiesService")
|
@ -1,12 +1,15 @@
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from core.base.redisdb import RedisDB
|
from core.base_service import BaseService
|
||||||
|
from core.basemodel import RegionEnum
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
|
from core.services.cookies.error import CookiesCachePoolExhausted
|
||||||
from utils.error import RegionNotFoundError
|
from utils.error import RegionNotFoundError
|
||||||
from utils.models.base import RegionEnum
|
|
||||||
from .error import CookiesCachePoolExhausted
|
__all__ = ("PublicCookiesCache",)
|
||||||
|
|
||||||
|
|
||||||
class PublicCookiesCache:
|
class PublicCookiesCache(BaseService.Component):
|
||||||
"""使用优先级(score)进行排序,对使用次数最少的Cookies进行审核"""
|
"""使用优先级(score)进行排序,对使用次数最少的Cookies进行审核"""
|
||||||
|
|
||||||
def __init__(self, redis: RedisDB):
|
def __init__(self, redis: RedisDB):
|
||||||
@ -19,10 +22,9 @@ class PublicCookiesCache:
|
|||||||
def get_public_cookies_queue_name(self, region: RegionEnum):
|
def get_public_cookies_queue_name(self, region: RegionEnum):
|
||||||
if region == RegionEnum.HYPERION:
|
if region == RegionEnum.HYPERION:
|
||||||
return f"{self.score_qname}:yuanshen"
|
return f"{self.score_qname}:yuanshen"
|
||||||
elif region == RegionEnum.HOYOLAB:
|
if region == RegionEnum.HOYOLAB:
|
||||||
return f"{self.score_qname}:genshin"
|
return f"{self.score_qname}:genshin"
|
||||||
else:
|
raise RegionNotFoundError(region.name)
|
||||||
raise RegionNotFoundError(region.name)
|
|
||||||
|
|
||||||
async def putback_public_cookies(self, uid: int, region: RegionEnum):
|
async def putback_public_cookies(self, uid: int, region: RegionEnum):
|
||||||
"""重新添加单个到缓存列表
|
"""重新添加单个到缓存列表
|
@ -7,11 +7,6 @@ class CookiesCachePoolExhausted(CookieServiceError):
|
|||||||
super().__init__("Cookies cache pool is exhausted")
|
super().__init__("Cookies cache pool is exhausted")
|
||||||
|
|
||||||
|
|
||||||
class CookiesNotFoundError(CookieServiceError):
|
|
||||||
def __init__(self, user_id):
|
|
||||||
super().__init__(f"{user_id} cookies not found")
|
|
||||||
|
|
||||||
|
|
||||||
class TooManyRequestPublicCookies(CookieServiceError):
|
class TooManyRequestPublicCookies(CookieServiceError):
|
||||||
def __init__(self, user_id):
|
def __init__(self, user_id):
|
||||||
super().__init__(f"{user_id} too many request public cookies")
|
super().__init__(f"{user_id} too many request public cookies")
|
39
core/services/cookies/models.py
Normal file
39
core/services/cookies/models.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import enum
|
||||||
|
from typing import Optional, Dict
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index
|
||||||
|
|
||||||
|
from core.basemodel import RegionEnum
|
||||||
|
|
||||||
|
__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum")
|
||||||
|
|
||||||
|
|
||||||
|
class CookiesStatusEnum(int, enum.Enum):
|
||||||
|
STATUS_SUCCESS = 0
|
||||||
|
INVALID_COOKIES = 1
|
||||||
|
TOO_MANY_REQUESTS = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Cookies(SQLModel):
|
||||||
|
__table_args__ = (
|
||||||
|
Index("index_user_account", "user_id", "account_id", unique=True),
|
||||||
|
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||||
|
)
|
||||||
|
id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True))
|
||||||
|
user_id: int = Field(
|
||||||
|
sa_column=Column(BigInteger()),
|
||||||
|
)
|
||||||
|
account_id: int = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column(
|
||||||
|
BigInteger(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON))
|
||||||
|
status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum)))
|
||||||
|
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
||||||
|
is_share: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||||
|
|
||||||
|
|
||||||
|
class CookiesDataBase(Cookies, table=True):
|
||||||
|
__tablename__ = "cookies"
|
55
core/services/cookies/repositories.py
Normal file
55
core/services/cookies/repositories.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.basemodel import RegionEnum
|
||||||
|
from core.dependence.mysql import MySQL
|
||||||
|
from core.services.cookies.models import CookiesDataBase as Cookies
|
||||||
|
from core.sqlmodel.session import AsyncSession
|
||||||
|
|
||||||
|
__all__ = ("CookiesRepository",)
|
||||||
|
|
||||||
|
|
||||||
|
class CookiesRepository(BaseService.Component):
|
||||||
|
def __init__(self, mysql: MySQL):
|
||||||
|
self.engine = mysql.engine
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
account_id: Optional[int] = None,
|
||||||
|
region: Optional[RegionEnum] = None,
|
||||||
|
) -> Optional[Cookies]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(Cookies).where(Cookies.user_id == user_id)
|
||||||
|
if account_id is not None:
|
||||||
|
statement = statement.where(Cookies.account_id == account_id)
|
||||||
|
if region is not None:
|
||||||
|
statement = statement.where(Cookies.region == region)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.first()
|
||||||
|
|
||||||
|
async def add(self, cookies: Cookies) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(cookies)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def update(self, cookies: Cookies) -> Cookies:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(cookies)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(cookies)
|
||||||
|
return cookies
|
||||||
|
|
||||||
|
async def delete(self, cookies: Cookies) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
await session.delete(cookies)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(Cookies).where(Cookies.region == region)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
cookies = results.all()
|
||||||
|
return cookies
|
@ -1,67 +1,73 @@
|
|||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import genshin
|
import genshin
|
||||||
from genshin import GenshinException, InvalidCookies, TooManyRequests, types, Game
|
from genshin import Game, GenshinException, InvalidCookies, TooManyRequests, types
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.basemodel import RegionEnum
|
||||||
|
from core.services.cookies.cache import PublicCookiesCache
|
||||||
|
from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies
|
||||||
|
from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum
|
||||||
|
from core.services.cookies.repositories import CookiesRepository
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
from utils.models.base import RegionEnum
|
|
||||||
from .cache import PublicCookiesCache
|
__all__ = ("CookiesService", "PublicCookiesService")
|
||||||
from .error import TooManyRequestPublicCookies, CookieServiceError
|
|
||||||
from .models import CookiesStatusEnum
|
|
||||||
from .repositories import CookiesNotFoundError, CookiesRepository
|
|
||||||
|
|
||||||
|
|
||||||
class CookiesService:
|
class CookiesService(BaseService):
|
||||||
def __init__(self, cookies_repository: CookiesRepository) -> None:
|
def __init__(self, cookies_repository: CookiesRepository) -> None:
|
||||||
self._repository: CookiesRepository = cookies_repository
|
self._repository: CookiesRepository = cookies_repository
|
||||||
|
|
||||||
async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
async def update(self, cookies: Cookies):
|
||||||
await self._repository.update_cookies(user_id, cookies, region)
|
await self._repository.update(cookies)
|
||||||
|
|
||||||
async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
async def add(self, cookies: Cookies):
|
||||||
await self._repository.add_cookies(user_id, cookies, region)
|
await self._repository.add(cookies)
|
||||||
|
|
||||||
async def get_cookies(self, user_id: int, region: RegionEnum):
|
async def get(
|
||||||
return await self._repository.get_cookies(user_id, region)
|
self,
|
||||||
|
user_id: int,
|
||||||
|
account_id: Optional[int] = None,
|
||||||
|
region: Optional[RegionEnum] = None,
|
||||||
|
) -> Optional[Cookies]:
|
||||||
|
return await self._repository.get(user_id, account_id, region)
|
||||||
|
|
||||||
async def del_cookies(self, user_id: int, region: RegionEnum):
|
async def delete(self, cookies: Cookies) -> None:
|
||||||
return await self._repository.del_cookies(user_id, region)
|
return await self._repository.delete(cookies)
|
||||||
|
|
||||||
async def add_or_update_cookies(self, user_id: int, cookies: dict, region: RegionEnum):
|
|
||||||
try:
|
|
||||||
await self.get_cookies(user_id, region)
|
|
||||||
await self.update_cookies(user_id, cookies, region)
|
|
||||||
except CookiesNotFoundError:
|
|
||||||
await self.add_cookies(user_id, cookies, region)
|
|
||||||
|
|
||||||
|
|
||||||
class PublicCookiesService:
|
class PublicCookiesService(BaseService):
|
||||||
def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache):
|
def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache):
|
||||||
self._cache = public_cookies_cache
|
self._cache = public_cookies_cache
|
||||||
self._repository: CookiesRepository = cookies_repository
|
self._repository: CookiesRepository = cookies_repository
|
||||||
self.count: int = 0
|
self.count: int = 0
|
||||||
self.user_times_limiter = 3 * 3
|
self.user_times_limiter = 3 * 3
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
logger.info("正在初始化公共Cookies池")
|
||||||
|
await self.refresh()
|
||||||
|
logger.success("刷新公共Cookies池成功")
|
||||||
|
|
||||||
async def refresh(self):
|
async def refresh(self):
|
||||||
"""刷新公共Cookies 定时任务
|
"""刷新公共Cookies 定时任务
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
user_list: List[int] = []
|
user_list: List[int] = []
|
||||||
cookies_list = await self._repository.get_all_cookies(RegionEnum.HYPERION) # 从数据库获取2
|
cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2
|
||||||
for cookies in cookies_list:
|
for cookies in cookies_list:
|
||||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
||||||
user_list.append(cookies.user_id)
|
user_list.append(cookies.user_id)
|
||||||
if len(user_list) > 0:
|
if len(user_list) > 0:
|
||||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION)
|
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION)
|
||||||
logger.info(f"国服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]")
|
logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
||||||
user_list.clear()
|
user_list.clear()
|
||||||
cookies_list = await self._repository.get_all_cookies(RegionEnum.HOYOLAB)
|
cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB)
|
||||||
for cookies in cookies_list:
|
for cookies in cookies_list:
|
||||||
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS:
|
||||||
user_list.append(cookies.user_id)
|
user_list.append(cookies.user_id)
|
||||||
if len(user_list) > 0:
|
if len(user_list) > 0:
|
||||||
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB)
|
add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB)
|
||||||
logger.info(f"国际服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]")
|
logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count)
|
||||||
|
|
||||||
async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL):
|
async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL):
|
||||||
"""获取公共Cookies
|
"""获取公共Cookies
|
||||||
@ -71,20 +77,19 @@ class PublicCookiesService:
|
|||||||
"""
|
"""
|
||||||
user_times = await self._cache.incr_by_user_times(user_id)
|
user_times = await self._cache.incr_by_user_times(user_id)
|
||||||
if int(user_times) > self.user_times_limiter:
|
if int(user_times) > self.user_times_limiter:
|
||||||
logger.warning(f"用户 [{user_id}] 使用公共Cookie次数已经到达上限")
|
logger.warning("用户 %s 使用公共Cookie次数已经到达上限", user_id)
|
||||||
raise TooManyRequestPublicCookies(user_id)
|
raise TooManyRequestPublicCookies(user_id)
|
||||||
while True:
|
while True:
|
||||||
public_id, count = await self._cache.get_public_cookies(region)
|
public_id, count = await self._cache.get_public_cookies(region)
|
||||||
try:
|
cookies = await self._repository.get(public_id, region=region)
|
||||||
cookies = await self._repository.get_cookies(public_id, region)
|
if cookies is None:
|
||||||
except CookiesNotFoundError:
|
|
||||||
await self._cache.delete_public_cookies(public_id, region)
|
await self._cache.delete_public_cookies(public_id, region)
|
||||||
continue
|
continue
|
||||||
if region == RegionEnum.HYPERION:
|
if region == RegionEnum.HYPERION:
|
||||||
client = genshin.Client(cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.CHINESE)
|
client = genshin.Client(cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.CHINESE)
|
||||||
elif region == RegionEnum.HOYOLAB:
|
elif region == RegionEnum.HOYOLAB:
|
||||||
client = genshin.Client(
|
client = genshin.Client(
|
||||||
cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn"
|
cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise CookieServiceError
|
raise CookieServiceError
|
||||||
@ -101,13 +106,13 @@ class PublicCookiesService:
|
|||||||
logger.warning("Cookies无效 ")
|
logger.warning("Cookies无效 ")
|
||||||
logger.exception(exc)
|
logger.exception(exc)
|
||||||
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
cookies.status = CookiesStatusEnum.INVALID_COOKIES
|
||||||
await self._repository.update_cookies_ex(cookies, region)
|
await self._repository.update(cookies)
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||||
continue
|
continue
|
||||||
except TooManyRequests:
|
except TooManyRequests:
|
||||||
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
|
logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id)
|
||||||
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
|
cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS
|
||||||
await self._repository.update_cookies_ex(cookies, region)
|
await self._repository.update(cookies)
|
||||||
await self._cache.delete_public_cookies(cookies.user_id, region)
|
await self._cache.delete_public_cookies(cookies.user_id, region)
|
||||||
continue
|
continue
|
||||||
except GenshinException as exc:
|
except GenshinException as exc:
|
1
core/services/game/__init__.py
Normal file
1
core/services/game/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""GameService"""
|
@ -1,12 +1,16 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from core.base.redisdb import RedisDB
|
from core.base_service import BaseService
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
|
|
||||||
|
__all__ = ["GameCache", "GameCacheForStrategy", "GameCacheForMaterial"]
|
||||||
|
|
||||||
|
|
||||||
class GameCache:
|
class GameCache:
|
||||||
def __init__(self, redis: RedisDB, qname: str, ttl: int = 3600):
|
qname: str
|
||||||
|
|
||||||
|
def __init__(self, redis: RedisDB, ttl: int = 3600):
|
||||||
self.client = redis.client
|
self.client = redis.client
|
||||||
self.qname = qname
|
|
||||||
self.ttl = ttl
|
self.ttl = ttl
|
||||||
|
|
||||||
async def get_url_list(self, character_name: str):
|
async def get_url_list(self, character_name: str):
|
||||||
@ -19,3 +23,11 @@ class GameCache:
|
|||||||
await self.client.lpush(qname, *str_list)
|
await self.client.lpush(qname, *str_list)
|
||||||
await self.client.expire(qname, self.ttl)
|
await self.client.expire(qname, self.ttl)
|
||||||
return await self.client.llen(qname)
|
return await self.client.llen(qname)
|
||||||
|
|
||||||
|
|
||||||
|
class GameCacheForStrategy(BaseService.Component, GameCache):
|
||||||
|
qname = "game:strategy"
|
||||||
|
|
||||||
|
|
||||||
|
class GameCacheForMaterial(BaseService.Component, GameCache):
|
||||||
|
qname = "game:material"
|
@ -1,11 +1,14 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.services.game.cache import GameCacheForMaterial, GameCacheForStrategy
|
||||||
from modules.apihelper.client.components.hyperion import Hyperion
|
from modules.apihelper.client.components.hyperion import Hyperion
|
||||||
from .cache import GameCache
|
|
||||||
|
__all__ = ("GameMaterialService", "GameStrategyService")
|
||||||
|
|
||||||
|
|
||||||
class GameStrategyService:
|
class GameStrategyService(BaseService):
|
||||||
def __init__(self, cache: GameCache, collections: Optional[List[int]] = None):
|
def __init__(self, cache: GameCacheForStrategy, collections: Optional[List[int]] = None):
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
self._hyperion = Hyperion()
|
self._hyperion = Hyperion()
|
||||||
if collections is None:
|
if collections is None:
|
||||||
@ -49,8 +52,8 @@ class GameStrategyService:
|
|||||||
return artwork_info.image_urls[0]
|
return artwork_info.image_urls[0]
|
||||||
|
|
||||||
|
|
||||||
class GameMaterialService:
|
class GameMaterialService(BaseService):
|
||||||
def __init__(self, cache: GameCache, collections: Optional[List[int]] = None):
|
def __init__(self, cache: GameCacheForMaterial, collections: Optional[List[int]] = None):
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
self._hyperion = Hyperion()
|
self._hyperion = Hyperion()
|
||||||
self._collections = [428421, 1164644] if collections is None else collections
|
self._collections = [428421, 1164644] if collections is None else collections
|
||||||
@ -91,9 +94,8 @@ class GameMaterialService:
|
|||||||
await self._cache.set_url_list(character_name, image_url_list)
|
await self._cache.set_url_list(character_name, image_url_list)
|
||||||
if len(image_url_list) == 0:
|
if len(image_url_list) == 0:
|
||||||
return ""
|
return ""
|
||||||
elif len(image_url_list) == 1:
|
if len(image_url_list) == 1:
|
||||||
return image_url_list[0]
|
return image_url_list[0]
|
||||||
elif character_name in self._special:
|
if character_name in self._special:
|
||||||
return image_url_list[2]
|
return image_url_list[2]
|
||||||
else:
|
return image_url_list[1]
|
||||||
return image_url_list[1]
|
|
3
core/services/players/__init__.py
Normal file
3
core/services/players/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .services import PlayersService
|
||||||
|
|
||||||
|
__all__ = ("PlayersService",)
|
2
core/services/players/error.py
Normal file
2
core/services/players/error.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
class PlayerNotFoundError(Exception):
|
||||||
|
pass
|
96
core/services/players/models.py
Normal file
96
core/services/players/models.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, BaseSettings
|
||||||
|
from sqlalchemy import TypeDecorator
|
||||||
|
from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime
|
||||||
|
|
||||||
|
from core.basemodel import RegionEnum
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ujson as jsonlib
|
||||||
|
except ImportError:
|
||||||
|
import json as jsonlib
|
||||||
|
|
||||||
|
__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel")
|
||||||
|
|
||||||
|
|
||||||
|
class Player(SQLModel):
|
||||||
|
__table_args__ = (
|
||||||
|
Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True),
|
||||||
|
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||||
|
)
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||||
|
)
|
||||||
|
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||||
|
account_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||||
|
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||||
|
region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum)))
|
||||||
|
is_chosen: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||||
|
|
||||||
|
|
||||||
|
class PlayersDataBase(Player, table=True):
|
||||||
|
__tablename__ = "players"
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraPlayerInfo(BaseModel):
|
||||||
|
class Config(BaseSettings.Config):
|
||||||
|
json_loads = jsonlib.loads
|
||||||
|
json_dumps = jsonlib.dumps
|
||||||
|
|
||||||
|
waifu_id: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223
|
||||||
|
impl = VARCHAR(length=521)
|
||||||
|
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
"""
|
||||||
|
:param value: ExtraPlayerInfo | obj | None
|
||||||
|
:param dialect:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if value is not None:
|
||||||
|
if isinstance(value, ExtraPlayerInfo):
|
||||||
|
return value.json()
|
||||||
|
raise TypeError
|
||||||
|
return value
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
"""
|
||||||
|
:param value: str | obj | None
|
||||||
|
:param dialect:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if value is not None:
|
||||||
|
return ExtraPlayerInfo.parse_raw(value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class PlayerInfo(SQLModel):
|
||||||
|
__table_args__ = (
|
||||||
|
Index("index_user_account_player", "user_id", "player_id", unique=True),
|
||||||
|
dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),
|
||||||
|
)
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||||
|
)
|
||||||
|
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||||
|
player_id: int = Field(primary_key=True, sa_column=Column(BigInteger()))
|
||||||
|
nickname: Optional[str] = Field()
|
||||||
|
signature: Optional[str] = Field()
|
||||||
|
hand_image: Optional[int] = Field()
|
||||||
|
name_card: Optional[int] = Field()
|
||||||
|
extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType))
|
||||||
|
create_time: Optional[datetime] = Field(
|
||||||
|
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
||||||
|
)
|
||||||
|
last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
||||||
|
is_update: Optional[bool] = Field(sa_column=Column(Boolean))
|
||||||
|
|
||||||
|
|
||||||
|
class PlayerInfoSQLModel(PlayerInfo, table=True):
|
||||||
|
__tablename__ = "players_info"
|
109
core/services/players/repositories.py
Normal file
109
core/services/players/repositories.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlmodel import select, delete
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.basemodel import RegionEnum
|
||||||
|
from core.dependence.mysql import MySQL
|
||||||
|
from core.services.players.models import PlayerInfoSQLModel
|
||||||
|
from core.services.players.models import PlayersDataBase as Player
|
||||||
|
from core.sqlmodel.session import AsyncSession
|
||||||
|
|
||||||
|
__all__ = ("PlayersRepository", "PlayerInfoRepository")
|
||||||
|
|
||||||
|
|
||||||
|
class PlayersRepository(BaseService.Component):
|
||||||
|
def __init__(self, mysql: MySQL):
|
||||||
|
self.engine = mysql.engine
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
player_id: Optional[int] = None,
|
||||||
|
account_id: Optional[int] = None,
|
||||||
|
region: Optional[RegionEnum] = None,
|
||||||
|
is_chosen: Optional[bool] = None,
|
||||||
|
) -> Optional[Player]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(Player).where(Player.user_id == user_id)
|
||||||
|
if player_id is not None:
|
||||||
|
statement = statement.where(Player.player_id == player_id)
|
||||||
|
if account_id is not None:
|
||||||
|
statement = statement.where(Player.account_id == account_id)
|
||||||
|
if region is not None:
|
||||||
|
statement = statement.where(Player.region == region)
|
||||||
|
if is_chosen is not None:
|
||||||
|
statement = statement.where(Player.is_chosen == is_chosen)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.first()
|
||||||
|
|
||||||
|
async def add(self, player: Player) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(player)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def delete(self, player: Player) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
await session.delete(player)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def update(self, player: Player) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(player)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(player)
|
||||||
|
|
||||||
|
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(Player).where(Player.user_id == user_id)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
players = results.all()
|
||||||
|
return players
|
||||||
|
|
||||||
|
|
||||||
|
class PlayerInfoRepository(BaseService.Component):
|
||||||
|
def __init__(self, mysql: MySQL):
|
||||||
|
self.engine = mysql.engine
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
player_id: int,
|
||||||
|
) -> Optional[PlayerInfoSQLModel]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = (
|
||||||
|
select(PlayerInfoSQLModel)
|
||||||
|
.where(PlayerInfoSQLModel.player_id == player_id)
|
||||||
|
.where(PlayerInfoSQLModel.user_id == user_id)
|
||||||
|
)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.first()
|
||||||
|
|
||||||
|
async def add(self, player: PlayerInfoSQLModel) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(player)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def delete(self, player: PlayerInfoSQLModel) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
await session.delete(player)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def delete_by_id(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
player_id: int,
|
||||||
|
) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = (
|
||||||
|
delete(PlayerInfoSQLModel)
|
||||||
|
.where(PlayerInfoSQLModel.player_id == player_id)
|
||||||
|
.where(PlayerInfoSQLModel.user_id == user_id)
|
||||||
|
)
|
||||||
|
await session.execute(statement)
|
||||||
|
|
||||||
|
async def update(self, player: PlayerInfoSQLModel) -> None:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(player)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(player)
|
184
core/services/players/services.py
Normal file
184
core/services/players/services.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from aiohttp import ClientConnectorError
|
||||||
|
from enkanetwork import (
|
||||||
|
EnkaNetworkAPI,
|
||||||
|
VaildateUIDError,
|
||||||
|
HTTPException,
|
||||||
|
EnkaPlayerNotFound,
|
||||||
|
PlayerInfo as EnkaPlayerInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.basemodel import RegionEnum
|
||||||
|
from core.config import config
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
|
from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo
|
||||||
|
from core.services.players.repositories import PlayersRepository, PlayerInfoRepository
|
||||||
|
from utils.enkanetwork import RedisCache
|
||||||
|
from utils.log import logger
|
||||||
|
from utils.patch.aiohttp import AioHttpTimeoutException
|
||||||
|
|
||||||
|
__all__ = ("PlayersService", "PlayerInfoService")
|
||||||
|
|
||||||
|
|
||||||
|
class PlayersService(BaseService):
|
||||||
|
def __init__(self, players_repository: PlayersRepository) -> None:
|
||||||
|
self._repository = players_repository
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
player_id: Optional[int] = None,
|
||||||
|
account_id: Optional[int] = None,
|
||||||
|
region: Optional[RegionEnum] = None,
|
||||||
|
is_chosen: Optional[bool] = None,
|
||||||
|
) -> Optional[Player]:
|
||||||
|
return await self._repository.get(user_id, player_id, account_id, region, is_chosen)
|
||||||
|
|
||||||
|
async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]:
|
||||||
|
return await self._repository.get(user_id, region=region, is_chosen=True)
|
||||||
|
|
||||||
|
async def add(self, player: Player) -> None:
|
||||||
|
await self._repository.add(player)
|
||||||
|
|
||||||
|
async def update(self, player: Player) -> None:
|
||||||
|
await self._repository.update(player)
|
||||||
|
|
||||||
|
async def get_all_by_user_id(self, user_id: int) -> List[Player]:
|
||||||
|
return await self._repository.get_all_by_user_id(user_id)
|
||||||
|
|
||||||
|
async def remove_all_by_user_id(self, user_id: int):
|
||||||
|
players = await self._repository.get_all_by_user_id(user_id)
|
||||||
|
for player in players:
|
||||||
|
await self._repository.delete(player)
|
||||||
|
|
||||||
|
async def delete(self, player: Player):
|
||||||
|
await self._repository.delete(player)
|
||||||
|
|
||||||
|
|
||||||
|
class PlayerInfoService(BaseService):
|
||||||
|
def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository):
|
||||||
|
self.cache = redis.client
|
||||||
|
self._players_info_repository = players_info_repository
|
||||||
|
self.enka_client = EnkaNetworkAPI(lang="chs", user_agent=config.enka_network_api_agent)
|
||||||
|
self.enka_client.set_cache(RedisCache(redis.client, key="players_info:enka_network", ttl=60))
|
||||||
|
self.qname = "players_info"
|
||||||
|
|
||||||
|
async def get_form_cache(self, player: Player):
|
||||||
|
qname = f"{self.qname}:{player.user_id}:{player.player_id}"
|
||||||
|
data = await self.cache.get(qname)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
json_data = str(data, encoding="utf-8")
|
||||||
|
return PlayerInfo.parse_raw(json_data)
|
||||||
|
|
||||||
|
async def set_form_cache(self, player: PlayerInfo):
|
||||||
|
qname = f"{self.qname}:{player.user_id}:{player.player_id}"
|
||||||
|
await self.cache.set(qname, player.json(), ex=60)
|
||||||
|
|
||||||
|
async def get_player_info_from_enka(self, player_id: int) -> Optional[EnkaPlayerInfo]:
|
||||||
|
try:
|
||||||
|
response = await self.enka_client.fetch_user(player_id, info=True)
|
||||||
|
return response.player
|
||||||
|
except (VaildateUIDError, EnkaPlayerNotFound, HTTPException) as exc:
|
||||||
|
logger.warning("EnkaNetwork 请求失败: %s", str(exc))
|
||||||
|
except AioHttpTimeoutException as exc:
|
||||||
|
logger.warning("EnkaNetwork 请求超时: %s", str(exc))
|
||||||
|
except ClientConnectorError as exc:
|
||||||
|
logger.warning("EnkaNetwork 请求错误: %s", str(exc))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("EnkaNetwork 请求失败: %s", exc_info=exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get(self, player: Player) -> Optional[PlayerInfo]:
|
||||||
|
player_info = await self.get_form_cache(player)
|
||||||
|
if player_info is not None:
|
||||||
|
return player_info
|
||||||
|
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
|
||||||
|
if player_info is None:
|
||||||
|
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||||
|
if player_info_enka is None:
|
||||||
|
return None
|
||||||
|
player_info = PlayerInfo(
|
||||||
|
user_id=player.user_id,
|
||||||
|
player_id=player.player_id,
|
||||||
|
nickname=player_info_enka.nickname,
|
||||||
|
signature=player_info_enka.signature,
|
||||||
|
name_card=player_info_enka.namecard.id,
|
||||||
|
hand_image=player_info_enka.avatar.id,
|
||||||
|
create_time=datetime.now(),
|
||||||
|
last_save_time=datetime.now(),
|
||||||
|
is_update=True,
|
||||||
|
)
|
||||||
|
await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info))
|
||||||
|
await self.set_form_cache(player_info)
|
||||||
|
return player_info
|
||||||
|
if player_info.is_update:
|
||||||
|
expiration_time = datetime.now() - timedelta(days=7)
|
||||||
|
if player_info.last_save_time is None or player_info.last_save_time <= expiration_time:
|
||||||
|
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||||
|
if player_info_enka is None:
|
||||||
|
player_info.last_save_time = datetime.now()
|
||||||
|
await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info))
|
||||||
|
await self.set_form_cache(player_info)
|
||||||
|
return player_info
|
||||||
|
player_info.nickname = player_info_enka.nickname
|
||||||
|
player_info.name_card = player_info_enka.namecard.id
|
||||||
|
player_info.signature = player_info_enka.signature
|
||||||
|
player_info.hand_image = player_info_enka.avatar.id
|
||||||
|
player_info.nickname = player_info_enka.nickname
|
||||||
|
player_info.last_save_time = datetime.now()
|
||||||
|
await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info))
|
||||||
|
await self.set_form_cache(player_info)
|
||||||
|
return player_info
|
||||||
|
|
||||||
|
async def update_from_enka(self, player: Player) -> bool:
|
||||||
|
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
|
||||||
|
if player_info is not None:
|
||||||
|
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||||
|
if player_info_enka is None:
|
||||||
|
return False
|
||||||
|
player_info.nickname = player_info_enka.nickname
|
||||||
|
player_info.name_card = player_info_enka.namecard.id
|
||||||
|
player_info.signature = player_info_enka.signature
|
||||||
|
player_info.hand_image = player_info_enka.avatar.id
|
||||||
|
player_info.nickname = player_info_enka.nickname
|
||||||
|
player_info.last_save_time = datetime.now()
|
||||||
|
await self._players_info_repository.update(player_info)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def add_from_enka(self, player: Player) -> bool:
|
||||||
|
player_info = await self._players_info_repository.get(player.user_id, player.player_id)
|
||||||
|
if player_info is None:
|
||||||
|
player_info_enka = await self.get_player_info_from_enka(player.player_id)
|
||||||
|
if player_info_enka is None:
|
||||||
|
return False
|
||||||
|
player_info = PlayerInfoSQLModel(
|
||||||
|
user_id=player.user_id,
|
||||||
|
player_id=player.player_id,
|
||||||
|
nickname=player_info_enka.nickname,
|
||||||
|
signature=player_info_enka.signature,
|
||||||
|
name_card=player_info_enka.namecard.id,
|
||||||
|
hand_image=player_info_enka.avatar.id,
|
||||||
|
create_time=datetime.now(),
|
||||||
|
last_save_time=datetime.now(),
|
||||||
|
is_update=True,
|
||||||
|
)
|
||||||
|
await self._players_info_repository.add(player_info)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_form_sql(self, player: Player):
|
||||||
|
return await self._players_info_repository.get(player.user_id, player.player_id)
|
||||||
|
|
||||||
|
async def delete_form_player(self, player: Player):
|
||||||
|
await self._players_info_repository.delete_by_id(user_id=player.user_id, player_id=player.player_id)
|
||||||
|
|
||||||
|
async def add(self, player_info: PlayerInfo):
|
||||||
|
await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info))
|
||||||
|
|
||||||
|
async def delete(self, player_info: PlayerInfo):
|
||||||
|
await self._players_info_repository.delete(PlayerInfoSQLModel.from_orm(player_info))
|
1
core/services/quiz/__init__.py
Normal file
1
core/services/quiz/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""QuizService"""
|
@ -1,12 +1,13 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import ujson
|
from core.base_service import BaseService
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
|
from core.services.quiz.models import Answer, Question
|
||||||
|
|
||||||
from core.base.redisdb import RedisDB
|
__all__ = ("QuizCache",)
|
||||||
from .models import Answer, Question
|
|
||||||
|
|
||||||
|
|
||||||
class QuizCache:
|
class QuizCache(BaseService.Component):
|
||||||
def __init__(self, redis: RedisDB):
|
def __init__(self, redis: RedisDB):
|
||||||
self.client = redis.client
|
self.client = redis.client
|
||||||
self.question_qname = "quiz:question"
|
self.question_qname = "quiz:question"
|
||||||
@ -18,7 +19,7 @@ class QuizCache:
|
|||||||
data_list = [self.question_qname + f":{question_id}" for question_id in await self.client.lrange(qname, 0, -1)]
|
data_list = [self.question_qname + f":{question_id}" for question_id in await self.client.lrange(qname, 0, -1)]
|
||||||
data = await self.client.mget(data_list)
|
data = await self.client.mget(data_list)
|
||||||
for i in data:
|
for i in data:
|
||||||
temp_list.append(Question.de_json(ujson.loads(i)))
|
temp_list.append(Question.parse_raw(i))
|
||||||
return temp_list
|
return temp_list
|
||||||
|
|
||||||
async def get_all_question_id_list(self) -> List[str]:
|
async def get_all_question_id_list(self) -> List[str]:
|
||||||
@ -29,19 +30,19 @@ class QuizCache:
|
|||||||
qname = f"{self.question_qname}:{question_id}"
|
qname = f"{self.question_qname}:{question_id}"
|
||||||
data = await self.client.get(qname)
|
data = await self.client.get(qname)
|
||||||
json_data = str(data, encoding="utf-8")
|
json_data = str(data, encoding="utf-8")
|
||||||
return Question.de_json(ujson.loads(json_data))
|
return Question.parse_raw(json_data)
|
||||||
|
|
||||||
async def get_one_answer(self, answer_id: int) -> Answer:
|
async def get_one_answer(self, answer_id: int) -> Answer:
|
||||||
qname = f"{self.answer_qname}:{answer_id}"
|
qname = f"{self.answer_qname}:{answer_id}"
|
||||||
data = await self.client.get(qname)
|
data = await self.client.get(qname)
|
||||||
json_data = str(data, encoding="utf-8")
|
json_data = str(data, encoding="utf-8")
|
||||||
return Answer.de_json(ujson.loads(json_data))
|
return Answer.parse_raw(json_data)
|
||||||
|
|
||||||
async def add_question(self, question_list: List[Question] = None) -> int:
|
async def add_question(self, question_list: List[Question] = None) -> int:
|
||||||
if not question_list:
|
if not question_list:
|
||||||
return 0
|
return 0
|
||||||
for question in question_list:
|
for question in question_list:
|
||||||
await self.client.set(f"{self.question_qname}:{question.question_id}", ujson.dumps(question.to_dict()))
|
await self.client.set(f"{self.question_qname}:{question.question_id}", question.json())
|
||||||
question_id_list = [question.question_id for question in question_list]
|
question_id_list = [question.question_id for question in question_list]
|
||||||
await self.client.lpush(f"{self.question_qname}:id_list", *question_id_list)
|
await self.client.lpush(f"{self.question_qname}:id_list", *question_id_list)
|
||||||
return await self.client.llen(f"{self.question_qname}:id_list")
|
return await self.client.llen(f"{self.question_qname}:id_list")
|
||||||
@ -62,7 +63,7 @@ class QuizCache:
|
|||||||
if not answer_list:
|
if not answer_list:
|
||||||
return 0
|
return 0
|
||||||
for answer in answer_list:
|
for answer in answer_list:
|
||||||
await self.client.set(f"{self.answer_qname}:{answer.answer_id}", ujson.dumps(answer.to_dict()))
|
await self.client.set(f"{self.answer_qname}:{answer.answer_id}", answer.json())
|
||||||
answer_id_list = [answer.answer_id for answer in answer_list]
|
answer_id_list = [answer.answer_id for answer in answer_list]
|
||||||
await self.client.lpush(f"{self.answer_qname}:id_list", *answer_id_list)
|
await self.client.lpush(f"{self.answer_qname}:id_list", *answer_id_list)
|
||||||
return await self.client.llen(f"{self.answer_qname}:id_list")
|
return await self.client.llen(f"{self.answer_qname}:id_list")
|
57
core/services/quiz/models.py
Normal file
57
core/services/quiz/models.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlmodel import Column, Field, ForeignKey, Integer, SQLModel
|
||||||
|
|
||||||
|
__all__ = ("Answer", "AnswerDB", "Question", "QuestionDB")
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerDB(SQLModel, table=True):
|
||||||
|
__tablename__ = "answer"
|
||||||
|
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||||
|
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
)
|
||||||
|
question_id: Optional[int] = Field(
|
||||||
|
sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT"))
|
||||||
|
)
|
||||||
|
is_correct: Optional[bool] = Field()
|
||||||
|
text: Optional[str] = Field()
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionDB(SQLModel, table=True):
|
||||||
|
__tablename__ = "question"
|
||||||
|
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||||
|
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
)
|
||||||
|
text: Optional[str] = Field()
|
||||||
|
|
||||||
|
|
||||||
|
class Answer(BaseModel):
|
||||||
|
answer_id: int = 0
|
||||||
|
question_id: int = 0
|
||||||
|
is_correct: bool = True
|
||||||
|
text: str = ""
|
||||||
|
|
||||||
|
def to_database_data(self) -> AnswerDB:
|
||||||
|
return AnswerDB(id=self.answer_id, question_id=self.question_id, text=self.text, is_correct=self.is_correct)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def de_database_data(cls, data: AnswerDB) -> Optional["Answer"]:
|
||||||
|
return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct)
|
||||||
|
|
||||||
|
|
||||||
|
class Question(BaseModel):
|
||||||
|
question_id: int = 0
|
||||||
|
text: str = ""
|
||||||
|
answers: List[Answer] = []
|
||||||
|
|
||||||
|
def to_database_data(self) -> QuestionDB:
|
||||||
|
return QuestionDB(text=self.text, id=self.question_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def de_database_data(cls, data: QuestionDB) -> Optional["Question"]:
|
||||||
|
return cls(question_id=data.id, text=data.text)
|
@ -2,54 +2,55 @@ from typing import List
|
|||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
from core.base.mysql import MySQL
|
from core.base_service import BaseService
|
||||||
from .models import AnswerDB, QuestionDB
|
from core.dependence.mysql import MySQL
|
||||||
|
from core.services.quiz.models import AnswerDB, QuestionDB
|
||||||
|
from core.sqlmodel.session import AsyncSession
|
||||||
|
|
||||||
|
__all__ = ("QuizRepository",)
|
||||||
|
|
||||||
|
|
||||||
class QuizRepository:
|
class QuizRepository(BaseService.Component):
|
||||||
def __init__(self, mysql: MySQL):
|
def __init__(self, mysql: MySQL):
|
||||||
self.mysql = mysql
|
self.engine = mysql.engine
|
||||||
|
|
||||||
async def get_question_list(self) -> List[QuestionDB]:
|
async def get_question_list(self) -> List[QuestionDB]:
|
||||||
async with self.mysql.Session() as session:
|
async with AsyncSession(self.engine) as session:
|
||||||
query = select(QuestionDB)
|
query = select(QuestionDB)
|
||||||
results = await session.exec(query)
|
results = await session.exec(query)
|
||||||
questions = results.all()
|
return results.all()
|
||||||
return questions
|
|
||||||
|
|
||||||
async def get_answers_from_question_id(self, question_id: int) -> List[AnswerDB]:
|
async def get_answers_from_question_id(self, question_id: int) -> List[AnswerDB]:
|
||||||
async with self.mysql.Session() as session:
|
async with AsyncSession(self.engine) as session:
|
||||||
query = select(AnswerDB).where(AnswerDB.question_id == question_id)
|
query = select(AnswerDB).where(AnswerDB.question_id == question_id)
|
||||||
results = await session.exec(query)
|
results = await session.exec(query)
|
||||||
answers = results.all()
|
return results.all()
|
||||||
return answers
|
|
||||||
|
|
||||||
async def add_question(self, question: QuestionDB):
|
async def add_question(self, question: QuestionDB):
|
||||||
async with self.mysql.Session() as session:
|
async with AsyncSession(self.engine) as session:
|
||||||
session.add(question)
|
session.add(question)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def get_question_by_text(self, text: str) -> QuestionDB:
|
async def get_question_by_text(self, text: str) -> QuestionDB:
|
||||||
async with self.mysql.Session() as session:
|
async with AsyncSession(self.engine) as session:
|
||||||
query = select(QuestionDB).where(QuestionDB.text == text)
|
query = select(QuestionDB).where(QuestionDB.text == text)
|
||||||
results = await session.exec(query)
|
results = await session.exec(query)
|
||||||
question = results.first()
|
return results.first()
|
||||||
return question[0]
|
|
||||||
|
|
||||||
async def add_answer(self, answer: AnswerDB):
|
async def add_answer(self, answer: AnswerDB):
|
||||||
async with self.mysql.Session() as session:
|
async with AsyncSession(self.engine) as session:
|
||||||
session.add(answer)
|
session.add(answer)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def delete_question_by_id(self, question_id: int):
|
async def delete_question_by_id(self, question_id: int):
|
||||||
async with self.mysql.Session() as session:
|
async with AsyncSession(self.engine) as session:
|
||||||
statement = select(QuestionDB).where(QuestionDB.id == question_id)
|
statement = select(QuestionDB).where(QuestionDB.id == question_id)
|
||||||
results = await session.exec(statement)
|
results = await session.exec(statement)
|
||||||
question = results.one()
|
question = results.one()
|
||||||
await session.delete(question)
|
await session.delete(question)
|
||||||
|
|
||||||
async def delete_answer_by_id(self, answer_id: int):
|
async def delete_answer_by_id(self, answer_id: int):
|
||||||
async with self.mysql.Session() as session:
|
async with AsyncSession(self.engine) as session:
|
||||||
statement = select(AnswerDB).where(AnswerDB.id == answer_id)
|
statement = select(AnswerDB).where(AnswerDB.id == answer_id)
|
||||||
results = await session.exec(statement)
|
results = await session.exec(statement)
|
||||||
answer = results.one()
|
answer = results.one()
|
@ -1,12 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from .cache import QuizCache
|
from core.base_service import BaseService
|
||||||
from .models import Answer, Question
|
from core.services.quiz.cache import QuizCache
|
||||||
from .repositories import QuizRepository
|
from core.services.quiz.models import Answer, Question
|
||||||
|
from core.services.quiz.repositories import QuizRepository
|
||||||
|
|
||||||
|
__all__ = ("QuizService",)
|
||||||
|
|
||||||
|
|
||||||
class QuizService:
|
class QuizService(BaseService):
|
||||||
def __init__(self, repository: QuizRepository, cache: QuizCache):
|
def __init__(self, repository: QuizRepository, cache: QuizCache):
|
||||||
self._repository = repository
|
self._repository = repository
|
||||||
self._cache = cache
|
self._cache = cache
|
1
core/services/search/__init__.py
Normal file
1
core/services/search/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""SearchService"""
|
@ -1,12 +1,11 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Optional, List
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
__all__ = ["BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList"]
|
|
||||||
|
|
||||||
from thefuzz import fuzz
|
from thefuzz import fuzz
|
||||||
|
|
||||||
|
__all__ = ("BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList")
|
||||||
|
|
||||||
|
|
||||||
class BaseEntry(BaseModel):
|
class BaseEntry(BaseModel):
|
||||||
"""所有可搜索条目的基类。
|
"""所有可搜索条目的基类。
|
@ -5,19 +5,22 @@ import json
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, List, Optional, Dict
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from async_lru import alru_cache
|
from async_lru import alru_cache
|
||||||
|
|
||||||
from core.search.models import WeaponEntry, BaseEntry, WeaponsEntry, StrategyEntry, StrategyEntryList
|
from core.base_service import BaseService
|
||||||
|
from core.services.search.models import BaseEntry, StrategyEntry, StrategyEntryList, WeaponEntry, WeaponsEntry
|
||||||
from utils.const import PROJECT_ROOT
|
from utils.const import PROJECT_ROOT
|
||||||
|
|
||||||
|
__all__ = ("SearchServices",)
|
||||||
|
|
||||||
ENTRY_DAYA_PATH = PROJECT_ROOT.joinpath("data", "entry")
|
ENTRY_DAYA_PATH = PROJECT_ROOT.joinpath("data", "entry")
|
||||||
ENTRY_DAYA_PATH.mkdir(parents=True, exist_ok=True)
|
ENTRY_DAYA_PATH.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
class SearchServices:
|
class SearchServices(BaseService):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._lock = asyncio.Lock() # 访问和修改操作成员变量必须加锁操作
|
self._lock = asyncio.Lock() # 访问和修改操作成员变量必须加锁操作
|
||||||
self.weapons: List[WeaponEntry] = []
|
self.weapons: List[WeaponEntry] = []
|
1
core/services/sign/__init__.py
Normal file
1
core/services/sign/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""SignService"""
|
@ -2,8 +2,10 @@ import enum
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func, BigInteger
|
||||||
from sqlmodel import SQLModel, Field, Enum, Column, DateTime
|
from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer
|
||||||
|
|
||||||
|
__all__ = ("SignStatusEnum", "Sign")
|
||||||
|
|
||||||
|
|
||||||
class SignStatusEnum(int, enum.Enum):
|
class SignStatusEnum(int, enum.Enum):
|
||||||
@ -19,10 +21,13 @@ class SignStatusEnum(int, enum.Enum):
|
|||||||
|
|
||||||
class Sign(SQLModel, table=True):
|
class Sign(SQLModel, table=True):
|
||||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||||
|
id: Optional[int] = Field(
|
||||||
id: int = Field(primary_key=True)
|
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||||
user_id: int = Field(foreign_key="user.user_id")
|
)
|
||||||
|
user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True))
|
||||||
chat_id: Optional[int] = Field(default=None)
|
chat_id: Optional[int] = Field(default=None)
|
||||||
time_created: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), server_default=func.now()))
|
time_created: Optional[datetime] = Field(
|
||||||
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), onupdate=func.now()))
|
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102
|
||||||
|
)
|
||||||
|
time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102
|
||||||
status: Optional[SignStatusEnum] = Field(sa_column=Column(Enum(SignStatusEnum)))
|
status: Optional[SignStatusEnum] = Field(sa_column=Column(Enum(SignStatusEnum)))
|
50
core/services/sign/repositories.py
Normal file
50
core/services/sign/repositories.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.dependence.mysql import MySQL
|
||||||
|
from core.services.sign.models import Sign
|
||||||
|
from core.sqlmodel.session import AsyncSession
|
||||||
|
|
||||||
|
__all__ = ("SignRepository",)
|
||||||
|
|
||||||
|
|
||||||
|
class SignRepository(BaseService.Component):
|
||||||
|
def __init__(self, mysql: MySQL):
|
||||||
|
self.engine = mysql.engine
|
||||||
|
|
||||||
|
async def add(self, sign: Sign):
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(sign)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def remove(self, sign: Sign):
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
await session.delete(sign)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def update(self, sign: Sign) -> Sign:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(sign)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(sign)
|
||||||
|
return sign
|
||||||
|
|
||||||
|
async def get_by_user_id(self, user_id: int) -> Optional[Sign]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(Sign).where(Sign.user_id == user_id)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.first()
|
||||||
|
|
||||||
|
async def get_by_chat_id(self, chat_id: int) -> Optional[List[Sign]]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(Sign).where(Sign.chat_id == chat_id)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.all()
|
||||||
|
|
||||||
|
async def get_all(self) -> List[Sign]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
query = select(Sign)
|
||||||
|
results = await session.exec(query)
|
||||||
|
return results.all()
|
@ -1,8 +1,11 @@
|
|||||||
from .models import Sign
|
from core.base_service import BaseService
|
||||||
from .repositories import SignRepository
|
from core.services.sign.models import Sign
|
||||||
|
from core.services.sign.repositories import SignRepository
|
||||||
|
|
||||||
|
__all__ = ["SignServices"]
|
||||||
|
|
||||||
|
|
||||||
class SignServices:
|
class SignServices(BaseService):
|
||||||
def __init__(self, sign_repository: SignRepository) -> None:
|
def __init__(self, sign_repository: SignRepository) -> None:
|
||||||
self._repository: SignRepository = sign_repository
|
self._repository: SignRepository = sign_repository
|
||||||
|
|
1
core/services/template/__init__.py
Normal file
1
core/services/template/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""TemplateService"""
|
@ -3,10 +3,14 @@ import pickle # nosec B403
|
|||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.base.redisdb import RedisDB
|
from core.base_service import BaseService
|
||||||
|
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
|
|
||||||
|
__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"]
|
||||||
|
|
||||||
|
|
||||||
class TemplatePreviewCache:
|
class TemplatePreviewCache(BaseService.Component):
|
||||||
"""暂存渲染模板的数据用于预览"""
|
"""暂存渲染模板的数据用于预览"""
|
||||||
|
|
||||||
def __init__(self, redis: RedisDB):
|
def __init__(self, redis: RedisDB):
|
||||||
@ -29,7 +33,7 @@ class TemplatePreviewCache:
|
|||||||
return f"{self.qname}:{key}"
|
return f"{self.qname}:{key}"
|
||||||
|
|
||||||
|
|
||||||
class HtmlToFileIdCache:
|
class HtmlToFileIdCache(BaseService.Component):
|
||||||
"""html to file_id 的缓存"""
|
"""html to file_id 的缓存"""
|
||||||
|
|
||||||
def __init__(self, redis: RedisDB):
|
def __init__(self, redis: RedisDB):
|
@ -1,10 +1,12 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Union, List
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from telegram import Message, InputMediaPhoto, InputMediaDocument
|
from telegram import InputMediaDocument, InputMediaPhoto, Message
|
||||||
|
|
||||||
from core.template.cache import HtmlToFileIdCache
|
from core.services.template.cache import HtmlToFileIdCache
|
||||||
from core.template.error import ErrorFileType, FileIdNotFound
|
from core.services.template.error import ErrorFileType, FileIdNotFound
|
||||||
|
|
||||||
|
__all__ = ["FileType", "RenderResult", "RenderGroupResult"]
|
||||||
|
|
||||||
|
|
||||||
class FileType(Enum):
|
class FileType(Enum):
|
||||||
@ -16,10 +18,9 @@ class FileType(Enum):
|
|||||||
"""对应的 Telegram media 类型"""
|
"""对应的 Telegram media 类型"""
|
||||||
if file_type == FileType.PHOTO:
|
if file_type == FileType.PHOTO:
|
||||||
return InputMediaPhoto
|
return InputMediaPhoto
|
||||||
elif file_type == FileType.DOCUMENT:
|
if file_type == FileType.DOCUMENT:
|
||||||
return InputMediaDocument
|
return InputMediaDocument
|
||||||
else:
|
raise ErrorFileType
|
||||||
raise ErrorFileType
|
|
||||||
|
|
||||||
|
|
||||||
class RenderResult:
|
class RenderResult:
|
@ -1,44 +1,31 @@
|
|||||||
import time
|
import asyncio
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import (
|
from urllib.parse import urlencode, urljoin, urlsplit
|
||||||
urlencode,
|
|
||||||
urljoin,
|
|
||||||
urlsplit,
|
|
||||||
)
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import (
|
from fastapi.responses import FileResponse, HTMLResponse
|
||||||
FileResponse,
|
|
||||||
HTMLResponse,
|
|
||||||
)
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from jinja2 import (
|
from jinja2 import Environment, FileSystemLoader, Template
|
||||||
Environment,
|
|
||||||
FileSystemLoader,
|
|
||||||
Template,
|
|
||||||
)
|
|
||||||
from playwright.async_api import ViewportSize
|
from playwright.async_api import ViewportSize
|
||||||
|
|
||||||
from core.base.aiobrowser import AioBrowser
|
from core.application import Application
|
||||||
from core.base.webserver import webapp
|
from core.base_service import BaseService
|
||||||
from core.bot import bot
|
from core.config import config as application_config
|
||||||
from core.template.cache import (
|
from core.dependence.aiobrowser import AioBrowser
|
||||||
HtmlToFileIdCache,
|
from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache
|
||||||
TemplatePreviewCache,
|
from core.services.template.error import QuerySelectorNotFound
|
||||||
)
|
from core.services.template.models import FileType, RenderResult
|
||||||
from core.template.error import QuerySelectorNotFound
|
|
||||||
from core.template.models import (
|
|
||||||
FileType,
|
|
||||||
RenderResult,
|
|
||||||
)
|
|
||||||
from utils.const import PROJECT_ROOT
|
from utils.const import PROJECT_ROOT
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
|
__all__ = ("TemplateService", "TemplatePreviewer")
|
||||||
|
|
||||||
class TemplateService:
|
|
||||||
|
class TemplateService(BaseService):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
app: Application,
|
||||||
browser: AioBrowser,
|
browser: AioBrowser,
|
||||||
html_to_file_id_cache: HtmlToFileIdCache,
|
html_to_file_id_cache: HtmlToFileIdCache,
|
||||||
preview_cache: TemplatePreviewCache,
|
preview_cache: TemplatePreviewCache,
|
||||||
@ -51,10 +38,12 @@ class TemplateService:
|
|||||||
loader=FileSystemLoader(template_dir),
|
loader=FileSystemLoader(template_dir),
|
||||||
enable_async=True,
|
enable_async=True,
|
||||||
autoescape=True,
|
autoescape=True,
|
||||||
auto_reload=bot.config.debug,
|
auto_reload=application_config.debug,
|
||||||
)
|
)
|
||||||
|
self.using_preview = application_config.debug and application_config.webserver.enable
|
||||||
|
|
||||||
self.previewer = TemplatePreviewer(self, preview_cache)
|
if self.using_preview:
|
||||||
|
self.previewer = TemplatePreviewer(self, preview_cache, app.web_app)
|
||||||
|
|
||||||
self.html_to_file_id_cache = html_to_file_id_cache
|
self.html_to_file_id_cache = html_to_file_id_cache
|
||||||
|
|
||||||
@ -66,10 +55,11 @@ class TemplateService:
|
|||||||
:param template_name: 模板文件名
|
:param template_name: 模板文件名
|
||||||
:param template_data: 模板数据
|
:param template_data: 模板数据
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
loop = asyncio.get_event_loop()
|
||||||
|
start_time = loop.time()
|
||||||
template = self.get_template(template_name)
|
template = self.get_template(template_name)
|
||||||
html = await template.render_async(**template_data)
|
html = await template.render_async(**template_data)
|
||||||
logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}")
|
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||||
return html
|
return html
|
||||||
|
|
||||||
async def render(
|
async def render(
|
||||||
@ -100,19 +90,20 @@ class TemplateService:
|
|||||||
:param filename: 文件名字
|
:param filename: 文件名字
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
loop = asyncio.get_event_loop()
|
||||||
|
start_time = loop.time()
|
||||||
template = self.get_template(template_name)
|
template = self.get_template(template_name)
|
||||||
|
|
||||||
if bot.config.debug:
|
if self.using_preview:
|
||||||
preview_url = await self.previewer.get_preview_url(template_name, template_data)
|
preview_url = await self.previewer.get_preview_url(template_name, template_data)
|
||||||
logger.debug(f"调试模板 URL: {preview_url}")
|
logger.debug("调试模板 URL: \n%s", preview_url)
|
||||||
|
|
||||||
html = await template.render_async(**template_data)
|
html = await template.render_async(**template_data)
|
||||||
logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}")
|
logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||||
|
|
||||||
file_id = await self.html_to_file_id_cache.get_data(html, file_type.name)
|
file_id = await self.html_to_file_id_cache.get_data(html, file_type.name)
|
||||||
if file_id and not bot.config.debug:
|
if file_id and not application_config.debug:
|
||||||
logger.debug(f"{template_name} 命中缓存,返回 file_id {file_id}")
|
logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id)
|
||||||
return RenderResult(
|
return RenderResult(
|
||||||
html=html,
|
html=html,
|
||||||
photo=file_id,
|
photo=file_id,
|
||||||
@ -125,7 +116,7 @@ class TemplateService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
browser = await self._browser.get_browser()
|
browser = await self._browser.get_browser()
|
||||||
start_time = time.time()
|
start_time = loop.time()
|
||||||
page = await browser.new_page(viewport=viewport)
|
page = await browser.new_page(viewport=viewport)
|
||||||
uri = (PROJECT_ROOT / template.filename).as_uri()
|
uri = (PROJECT_ROOT / template.filename).as_uri()
|
||||||
await page.goto(uri)
|
await page.goto(uri)
|
||||||
@ -142,10 +133,10 @@ class TemplateService:
|
|||||||
if not clip:
|
if not clip:
|
||||||
raise QuerySelectorNotFound
|
raise QuerySelectorNotFound
|
||||||
except QuerySelectorNotFound:
|
except QuerySelectorNotFound:
|
||||||
logger.warning(f"未找到 {query_selector} 元素")
|
logger.warning("未找到 %s 元素", query_selector)
|
||||||
png_data = await page.screenshot(clip=clip, full_page=full_page)
|
png_data = await page.screenshot(clip=clip, full_page=full_page)
|
||||||
await page.close()
|
await page.close()
|
||||||
logger.debug(f"{template_name} 图片渲染使用了 {str(time.time() - start_time)}")
|
logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time))
|
||||||
return RenderResult(
|
return RenderResult(
|
||||||
html=html,
|
html=html,
|
||||||
photo=png_data,
|
photo=png_data,
|
||||||
@ -158,15 +149,21 @@ class TemplateService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TemplatePreviewer:
|
class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug):
|
||||||
def __init__(self, template_service: TemplateService, cache: TemplatePreviewCache):
|
def __init__(
|
||||||
|
self,
|
||||||
|
template_service: TemplateService,
|
||||||
|
cache: TemplatePreviewCache,
|
||||||
|
web_app: FastAPI,
|
||||||
|
):
|
||||||
|
self.web_app = web_app
|
||||||
self.template_service = template_service
|
self.template_service = template_service
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
self.register_routes()
|
self.register_routes()
|
||||||
|
|
||||||
async def get_preview_url(self, template: str, data: dict):
|
async def get_preview_url(self, template: str, data: dict):
|
||||||
"""获取预览 URL"""
|
"""获取预览 URL"""
|
||||||
components = urlsplit(bot.config.webserver.url)
|
components = urlsplit(application_config.webserver.url)
|
||||||
path = urljoin("/preview/", template)
|
path = urljoin("/preview/", template)
|
||||||
query = {}
|
query = {}
|
||||||
|
|
||||||
@ -176,12 +173,13 @@ class TemplatePreviewer:
|
|||||||
await self.cache.set_data(key, data)
|
await self.cache.set_data(key, data)
|
||||||
query["key"] = key
|
query["key"] = key
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
return components._replace(path=path, query=urlencode(query)).geturl()
|
return components._replace(path=path, query=urlencode(query)).geturl()
|
||||||
|
|
||||||
def register_routes(self):
|
def register_routes(self):
|
||||||
"""注册预览用到的路由"""
|
"""注册预览用到的路由"""
|
||||||
|
|
||||||
@webapp.get("/preview/{path:path}")
|
@self.web_app.get("/preview/{path:path}")
|
||||||
async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612
|
async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612
|
||||||
# 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源
|
# 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源
|
||||||
if not path.endswith(".html"):
|
if not path.endswith(".html"):
|
||||||
@ -206,4 +204,4 @@ class TemplatePreviewer:
|
|||||||
for name in ["cache", "resources"]:
|
for name in ["cache", "resources"]:
|
||||||
directory = PROJECT_ROOT / name
|
directory = PROJECT_ROOT / name
|
||||||
directory.mkdir(exist_ok=True)
|
directory.mkdir(exist_ok=True)
|
||||||
webapp.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)
|
self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)
|
0
core/services/users/__init__.py
Normal file
0
core/services/users/__init__.py
Normal file
24
core/services/users/cache.py
Normal file
24
core/services/users/cache.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
|
|
||||||
|
__all__ = ("UserAdminCache",)
|
||||||
|
|
||||||
|
|
||||||
|
class UserAdminCache(BaseService.Component):
|
||||||
|
def __init__(self, redis: RedisDB):
|
||||||
|
self.client = redis.client
|
||||||
|
self.qname = "users:admin"
|
||||||
|
|
||||||
|
async def ismember(self, user_id: int) -> bool:
|
||||||
|
return self.client.sismember(self.qname, user_id)
|
||||||
|
|
||||||
|
async def get_all(self) -> List[int]:
|
||||||
|
return [int(str_data) for str_data in await self.client.smembers(self.qname)]
|
||||||
|
|
||||||
|
async def set(self, user_id: int) -> bool:
|
||||||
|
return await self.client.sadd(self.qname, user_id)
|
||||||
|
|
||||||
|
async def remove(self, user_id: int) -> bool:
|
||||||
|
return await self.client.srem(self.qname, user_id)
|
34
core/services/users/models.py
Normal file
34
core/services/users/models.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import enum
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"User",
|
||||||
|
"UserDataBase",
|
||||||
|
"PermissionsEnum",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionsEnum(int, enum.Enum):
|
||||||
|
OWNER = 1
|
||||||
|
ADMIN = 2
|
||||||
|
PUBLIC = 3
|
||||||
|
|
||||||
|
|
||||||
|
class User(SQLModel):
|
||||||
|
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True)
|
||||||
|
)
|
||||||
|
user_id: int = Field(unique=True, sa_column=Column(BigInteger()))
|
||||||
|
permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum)))
|
||||||
|
locale: Optional[str] = Field()
|
||||||
|
ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
||||||
|
ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True)))
|
||||||
|
is_banned: Optional[int] = Field()
|
||||||
|
|
||||||
|
|
||||||
|
class UserDataBase(User, table=True):
|
||||||
|
__tablename__ = "users"
|
44
core/services/users/repositories.py
Normal file
44
core/services/users/repositories.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.dependence.mysql import MySQL
|
||||||
|
from core.services.users.models import UserDataBase as User
|
||||||
|
from core.sqlmodel.session import AsyncSession
|
||||||
|
|
||||||
|
__all__ = ("UserRepository",)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRepository(BaseService.Component):
|
||||||
|
def __init__(self, mysql: MySQL):
|
||||||
|
self.engine = mysql.engine
|
||||||
|
|
||||||
|
async def get_by_user_id(self, user_id: int) -> Optional[User]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(User).where(User.user_id == user_id)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.first()
|
||||||
|
|
||||||
|
async def add(self, user: User):
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def update(self, user: User) -> User:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def remove(self, user: User):
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
await session.delete(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def get_all(self) -> List[User]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(User)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.all()
|
79
core/services/users/services.py
Normal file
79
core/services/users/services.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from core.base_service import BaseService
|
||||||
|
from core.config import config
|
||||||
|
from core.services.users.cache import UserAdminCache
|
||||||
|
from core.services.users.models import PermissionsEnum, UserDataBase as User
|
||||||
|
from core.services.users.repositories import UserRepository
|
||||||
|
|
||||||
|
__all__ = ("UserService", "UserAdminService")
|
||||||
|
|
||||||
|
from utils.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class UserService(BaseService):
|
||||||
|
def __init__(self, user_repository: UserRepository) -> None:
|
||||||
|
self._repository: UserRepository = user_repository
|
||||||
|
|
||||||
|
async def get_user_by_id(self, user_id: int) -> Optional[User]:
|
||||||
|
"""从数据库获取用户信息
|
||||||
|
:param user_id:用户ID
|
||||||
|
:return: User
|
||||||
|
"""
|
||||||
|
return await self._repository.get_by_user_id(user_id)
|
||||||
|
|
||||||
|
async def remove(self, user: User):
|
||||||
|
return await self._repository.remove(user)
|
||||||
|
|
||||||
|
async def update_user(self, user: User):
|
||||||
|
return await self._repository.add(user)
|
||||||
|
|
||||||
|
|
||||||
|
class UserAdminService(BaseService):
|
||||||
|
def __init__(self, user_repository: UserRepository, cache: UserAdminCache):
|
||||||
|
self.user_repository = user_repository
|
||||||
|
self._cache = cache
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
owner = config.owner
|
||||||
|
if owner:
|
||||||
|
user = await self.user_repository.get_by_user_id(owner)
|
||||||
|
await self._cache.set(user.user_id)
|
||||||
|
if user:
|
||||||
|
if user.permissions != PermissionsEnum.OWNER:
|
||||||
|
user.permissions = PermissionsEnum.OWNER
|
||||||
|
await self.user_repository.update(user)
|
||||||
|
else:
|
||||||
|
user = User(user_id=owner, permissions=PermissionsEnum.OWNER)
|
||||||
|
await self.user_repository.add(user)
|
||||||
|
else:
|
||||||
|
logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限")
|
||||||
|
|
||||||
|
async def is_admin(self, user_id: int) -> bool:
|
||||||
|
return await self._cache.ismember(user_id)
|
||||||
|
|
||||||
|
async def get_admin_list(self) -> List[int]:
|
||||||
|
return await self._cache.get_all()
|
||||||
|
|
||||||
|
async def add_admin(self, user_id: int) -> bool:
|
||||||
|
user = await self.user_repository.get_by_user_id(user_id)
|
||||||
|
if user:
|
||||||
|
if user.permissions == PermissionsEnum.OWNER:
|
||||||
|
return False
|
||||||
|
if user.permissions != PermissionsEnum.ADMIN:
|
||||||
|
user.permissions = PermissionsEnum.ADMIN
|
||||||
|
await self.user_repository.update(user)
|
||||||
|
else:
|
||||||
|
user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN)
|
||||||
|
await self.user_repository.add(user)
|
||||||
|
return await self._cache.set(user.user_id)
|
||||||
|
|
||||||
|
async def delete_admin(self, user_id: int) -> bool:
|
||||||
|
user = await self.user_repository.get_by_user_id(user_id)
|
||||||
|
if user:
|
||||||
|
if user.permissions == PermissionsEnum.OWNER:
|
||||||
|
return True # 假装移除成功
|
||||||
|
user.permissions = PermissionsEnum.PUBLIC
|
||||||
|
await self.user_repository.update(user)
|
||||||
|
return await self._cache.remove(user.user_id)
|
||||||
|
return False
|
1
core/services/wiki/__init__.py
Normal file
1
core/services/wiki/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""WikiService"""
|
@ -1,10 +1,13 @@
|
|||||||
import ujson as json
|
import ujson as json
|
||||||
|
|
||||||
from core.base.redisdb import RedisDB
|
from core.base_service import BaseService
|
||||||
|
from core.dependence.redisdb import RedisDB
|
||||||
from modules.wiki.base import Model
|
from modules.wiki.base import Model
|
||||||
|
|
||||||
|
__all__ = ["WikiCache"]
|
||||||
|
|
||||||
class WikiCache:
|
|
||||||
|
class WikiCache(BaseService.Component):
|
||||||
def __init__(self, redis: RedisDB):
|
def __init__(self, redis: RedisDB):
|
||||||
self.client = redis.client
|
self.client = redis.client
|
||||||
self.qname = "wiki"
|
self.qname = "wiki"
|
@ -1,12 +1,15 @@
|
|||||||
from typing import List, NoReturn, Optional
|
from typing import List, NoReturn, Optional
|
||||||
|
|
||||||
from core.wiki.cache import WikiCache
|
from core.base_service import BaseService
|
||||||
|
from core.services.wiki.cache import WikiCache
|
||||||
from modules.wiki.character import Character
|
from modules.wiki.character import Character
|
||||||
from modules.wiki.weapon import Weapon
|
from modules.wiki.weapon import Weapon
|
||||||
from utils.log import logger
|
from utils.log import logger
|
||||||
|
|
||||||
|
__all__ = ["WikiService"]
|
||||||
|
|
||||||
class WikiService:
|
|
||||||
|
class WikiService(BaseService):
|
||||||
def __init__(self, cache: WikiCache):
|
def __init__(self, cache: WikiCache):
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
"""Redis 在这里的作用是作为持久化"""
|
"""Redis 在这里的作用是作为持久化"""
|
||||||
@ -18,7 +21,7 @@ class WikiService:
|
|||||||
|
|
||||||
async def refresh_weapon(self) -> NoReturn:
|
async def refresh_weapon(self) -> NoReturn:
|
||||||
weapon_name_list = await Weapon.get_name_list()
|
weapon_name_list = await Weapon.get_name_list()
|
||||||
logger.info(f"一共找到 {len(weapon_name_list)} 把武器信息")
|
logger.info("一共找到 %s 把武器信息", len(weapon_name_list))
|
||||||
|
|
||||||
weapon_list = []
|
weapon_list = []
|
||||||
num = 0
|
num = 0
|
||||||
@ -26,7 +29,7 @@ class WikiService:
|
|||||||
weapon_list.append(weapon)
|
weapon_list.append(weapon)
|
||||||
num += 1
|
num += 1
|
||||||
if num % 10 == 0:
|
if num % 10 == 0:
|
||||||
logger.info(f"现在已经获取到 {num} 把武器信息")
|
logger.info("现在已经获取到 %s 把武器信息", num)
|
||||||
|
|
||||||
logger.info("写入武器信息到Redis")
|
logger.info("写入武器信息到Redis")
|
||||||
self._weapon_list = weapon_list
|
self._weapon_list = weapon_list
|
||||||
@ -35,7 +38,7 @@ class WikiService:
|
|||||||
|
|
||||||
async def refresh_characters(self) -> NoReturn:
|
async def refresh_characters(self) -> NoReturn:
|
||||||
character_name_list = await Character.get_name_list()
|
character_name_list = await Character.get_name_list()
|
||||||
logger.info(f"一共找到 {len(character_name_list)} 个角色信息")
|
logger.info("一共找到 %s 个角色信息", len(character_name_list))
|
||||||
|
|
||||||
character_list = []
|
character_list = []
|
||||||
num = 0
|
num = 0
|
||||||
@ -43,7 +46,7 @@ class WikiService:
|
|||||||
character_list.append(character)
|
character_list.append(character)
|
||||||
num += 1
|
num += 1
|
||||||
if num % 10 == 0:
|
if num % 10 == 0:
|
||||||
logger.info(f"现在已经获取到 {num} 个角色信息")
|
logger.info("现在已经获取到 %s 个角色信息", num)
|
||||||
|
|
||||||
logger.info("写入角色信息到Redis")
|
logger.info("写入角色信息到Redis")
|
||||||
self._character_list = character_list
|
self._character_list = character_list
|
@ -1,11 +0,0 @@
|
|||||||
from core.base.mysql import MySQL
|
|
||||||
from core.service import init_service
|
|
||||||
from .repositories import SignRepository
|
|
||||||
from .services import SignServices
|
|
||||||
|
|
||||||
|
|
||||||
@init_service
|
|
||||||
def create_game_strategy_service(mysql: MySQL):
|
|
||||||
_repository = SignRepository(mysql)
|
|
||||||
_service = SignServices(_repository)
|
|
||||||
return _service
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user