🎨 Improve config

This commit is contained in:
Karako 2022-10-28 15:11:14 +08:00 committed by GitHub
parent be5f4c51d5
commit f0b287dcfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 150 additions and 156 deletions

View File

@ -1,20 +1,24 @@
from itertools import chain
import os
import asyncio import asyncio
import os
from importlib import import_module from importlib import import_module
from itertools import chain
from logging.config import fileConfig from logging.config import fileConfig
from typing import Iterator from typing import Iterator
from sqlalchemy import engine_from_config from alembic import context
from sqlalchemy import pool from sqlalchemy import (
engine_from_config,
pool,
)
from sqlalchemy.engine import Connection from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import SQLModel from sqlmodel import SQLModel
from alembic import context from utils.const import (
CORE_DIR,
from utils.const import CORE_DIR, PLUGIN_DIR, PROJECT_ROOT PLUGIN_DIR,
PROJECT_ROOT,
)
from utils.log import logger from utils.log import logger
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides

View File

@ -3,7 +3,10 @@ import asyncio
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from core.config import BotConfig, config as botConfig from core.config import (
BotConfig,
config as botConfig,
)
from core.service import Service from core.service import Service
__all__ = ["webapp", "WebServer"] __all__ = ["webapp", "WebServer"]
@ -28,7 +31,7 @@ class WebServer(Service):
@classmethod @classmethod
def from_config(cls, config: BotConfig) -> Service: def from_config(cls, config: BotConfig) -> Service:
return cls(debug=config.debug, **config.webserver.dict()) return cls(debug=config.debug, host=config.webserver.host, port=config.webserver.port)
def __init__(self, debug: bool, host: str, port: int): def __init__(self, debug: bool, host: str, port: int):
self.debug = debug self.debug = debug

View File

@ -7,10 +7,15 @@ from typing import (
) )
import dotenv import dotenv
import ujson as json from pydantic import (
from pydantic import BaseModel, BaseSettings, validator AnyUrl,
BaseModel,
Field,
validator,
)
from utils.const import PROJECT_ROOT from utils.const import PROJECT_ROOT
from utils.models.base import Settings
__all__ = ["BotConfig", "config", "JoinGroups"] __all__ = ["BotConfig", "config", "JoinGroups"]
@ -23,116 +28,6 @@ class JoinGroups(str, Enum):
ALLOW_ALL = "ALLOW_ALL" ALLOW_ALL = "ALLOW_ALL"
class BotConfig(BaseSettings):
debug: bool = False
db_host: str = ""
db_port: int = 0
db_username: str = ""
db_password: str = ""
db_database: str = ""
redis_host: str = ""
redis_port: int = 0
redis_db: int = 0
bot_token: str = ""
error_notification_chat_id: Optional[str] = None
api_id: Optional[int] = None
api_hash: Optional[str] = None
channels: List["ConfigChannel"] = []
admins: List["ConfigUser"] = []
verify_groups: List[Union[int, str]] = []
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
logger_name: str = "TGPaimon"
logger_width: int = 180
logger_log_path: str = "./logs"
logger_time_format: str = "[%Y-%m-%d %X]"
logger_traceback_max_frames: int = 20
logger_render_keywords: List[str] = ["BOT"]
logger_locals_max_depth: Optional[int] = 0
logger_locals_max_length: int = 10
logger_locals_max_string: int = 80
logger_filtered_names: List[str] = ["uvicorn"]
timeout: int = 10
read_timeout: float = 2
write_timeout: Optional[float] = None
connect_timeout: Optional[float] = None
pool_timeout: Optional[float] = None
genshin_ttl: Optional[int] = None
enka_network_api_agent: str = ""
pass_challenge_api: str = ""
pass_challenge_app_key: str = ""
web_url: str = "http://localhost:8080/"
web_host: str = "localhost"
web_port: int = 8080
error_pb_url: str = ""
error_pb_sunset: int = 43200
error_pb_max_lines: int = 1000
error_sentry_dsn: str = ""
class Config:
case_sensitive = False
json_loads = json.loads
json_dumps = json.dumps
@property
def mysql(self) -> "MySqlConfig":
return MySqlConfig(
host=self.db_host,
port=self.db_port,
username=self.db_username,
password=self.db_password,
database=self.db_database,
)
@property
def redis(self) -> "RedisConfig":
return RedisConfig(
host=self.redis_host,
port=self.redis_port,
database=self.redis_db,
)
@property
def logger(self) -> "LoggerConfig":
return LoggerConfig(
name=self.logger_name,
width=self.logger_width,
traceback_max_frames=self.logger_traceback_max_frames,
path=PROJECT_ROOT.joinpath(self.logger_log_path).resolve(),
time_format=self.logger_time_format,
render_keywords=self.logger_render_keywords,
locals_max_length=self.logger_locals_max_length,
locals_max_string=self.logger_locals_max_string,
locals_max_depth=self.logger_locals_max_depth,
filtered_names=self.logger_filtered_names,
)
@property
def mtproto(self) -> "MTProtoConfig":
return MTProtoConfig(
api_id=self.api_id,
api_hash=self.api_hash,
)
@property
def webserver(self) -> "WebServerConfig":
return WebServerConfig(
host=self.web_host,
port=self.web_port,
)
class ConfigChannel(BaseModel): class ConfigChannel(BaseModel):
name: str name: str
chat_id: int chat_id: int
@ -143,21 +38,27 @@ class ConfigUser(BaseModel):
user_id: int user_id: int
class MySqlConfig(BaseModel): class MySqlConfig(Settings):
host: str = "127.0.0.1" host: str = "127.0.0.1"
port: int = 3306 port: int = 3306
username: str username: str
password: str password: str
database: str database: str
class Config(Settings.Config):
env_prefix = "db_"
class RedisConfig(BaseModel):
class RedisConfig(Settings):
host: str = "127.0.0.1" host: str = "127.0.0.1"
port: int port: int = 6379
database: int = 0 database: int = Field(env="redis_db")
class Config(Settings.Config):
env_prefix = "redis_"
class LoggerConfig(BaseModel): class LoggerConfig(Settings):
name: str = "TGPaimon" name: str = "TGPaimon"
width: int = 180 width: int = 180
time_format: str = "[%Y-%m-%d %X]" time_format: str = "[%Y-%m-%d %X]"
@ -171,19 +72,67 @@ class LoggerConfig(BaseModel):
@validator("locals_max_depth", pre=True, check_fields=False) @validator("locals_max_depth", pre=True, check_fields=False)
def locals_max_depth_validator(cls, value) -> Optional[int]: # pylint: disable=R0201 def locals_max_depth_validator(cls, value) -> Optional[int]: # pylint: disable=R0201
if value <= 0: if int(value) <= 0:
return None return None
return value return value
class Config(Settings.Config):
class MTProtoConfig(BaseModel): env_prefix = "logger_"
api_id: Optional[int]
api_hash: Optional[str]
class WebServerConfig(BaseModel): class MTProtoConfig(Settings):
host: Optional[str] api_id: Optional[int] = None
port: Optional[int] api_hash: Optional[str] = None
class WebServerConfig(Settings):
url: AnyUrl = "http://localhost:8080"
host: str = "localhost"
port: int = 8080
class Config(Settings.Config):
env_prefix = "web_"
class ErrorConfig(Settings):
pb_url: str = ""
pb_sunset: int = 43200
pb_max_lines: int = 1000
sentry_dsn: str = ""
notification_chat_id: Optional[str] = None
class Config(Settings.Config):
env_prefix = "error_"
class BotConfig(Settings):
debug: bool = False
bot_token: str = ""
channels: List["ConfigChannel"] = []
admins: List["ConfigUser"] = []
verify_groups: List[Union[int, str]] = []
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
timeout: int = 10
read_timeout: float = 2
write_timeout: Optional[float] = None
connect_timeout: Optional[float] = None
pool_timeout: Optional[float] = None
genshin_ttl: Optional[int] = None
enka_network_api_agent: str = ""
pass_challenge_api: str = ""
pass_challenge_app_key: str = ""
mysql: MySqlConfig = MySqlConfig()
logger: LoggerConfig = LoggerConfig()
webserver: WebServerConfig = WebServerConfig()
redis: RedisConfig = RedisConfig()
mtproto: MTProtoConfig = MTProtoConfig()
error: ErrorConfig = ErrorConfig()
BotConfig.update_forward_refs() BotConfig.update_forward_refs()

View File

@ -1,20 +1,37 @@
import time import time
from typing import Optional from typing import Optional
from urllib.parse import urlencode, urljoin, urlsplit from urllib.parse import (
urlencode,
urljoin,
urlsplit,
)
from uuid import uuid4 from uuid import uuid4
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.responses import FileResponse, HTMLResponse from fastapi.responses import (
FileResponse,
HTMLResponse,
)
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, Template from jinja2 import (
Environment,
FileSystemLoader,
Template,
)
from playwright.async_api import ViewportSize from playwright.async_api import ViewportSize
from core.base.aiobrowser import AioBrowser from core.base.aiobrowser import AioBrowser
from core.base.webserver import webapp from core.base.webserver import webapp
from core.bot import bot from core.bot import bot
from core.template.cache import HtmlToFileIdCache, TemplatePreviewCache from core.template.cache import (
HtmlToFileIdCache,
TemplatePreviewCache,
)
from core.template.error import QuerySelectorNotFound from core.template.error import QuerySelectorNotFound
from core.template.models import FileType, RenderResult from core.template.models import (
FileType,
RenderResult,
)
from utils.const import PROJECT_ROOT from utils.const import PROJECT_ROOT
from utils.log import logger from utils.log import logger
@ -149,7 +166,7 @@ class TemplatePreviewer:
async def get_preview_url(self, template: str, data: dict): async def get_preview_url(self, template: str, data: dict):
"""获取预览 URL""" """获取预览 URL"""
components = urlsplit(bot.config.web_url) components = urlsplit(bot.config.webserver.url)
path = urljoin("/preview/", template) path = urljoin("/preview/", template)
query = {} query = {}
@ -187,4 +204,6 @@ class TemplatePreviewer:
# 其他静态资源 # 其他静态资源
for name in ["cache", "resources"]: for name in ["cache", "resources"]:
directory = PROJECT_ROOT / name
directory.mkdir(exist_ok=True)
webapp.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name) webapp.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)

View File

@ -7,16 +7,16 @@ from utils.log import logger
class PbClient: class PbClient:
def __init__(self): def __init__(self):
self.client = httpx.AsyncClient() self.client = httpx.AsyncClient()
self.PB_API = config.error_pb_url self.PB_API = config.error.pb_url
self.sunset: int = config.error_pb_sunset # 自动销毁时间 单位为秒 self.sunset: int = config.error.pb_sunset # 自动销毁时间 单位为秒
self.private: bool = True self.private: bool = True
self.max_lines: int = config.error_pb_max_lines self.max_lines: int = config.error.pb_max_lines
async def create_pb(self, content: str) -> str: async def create_pb(self, content: str) -> str:
if not self.PB_API: if not self.PB_API:
return "" return ""
logger.info("正在上传日记到 pb") logger.info("正在上传日记到 pb")
content = "\n".join(content.splitlines()[-self.max_lines:]) + "\n" content = "\n".join(content.splitlines()[-self.max_lines :]) + "\n"
data = { data = {
"c": content, "c": content,
} }

View File

@ -15,7 +15,7 @@ from utils.log import logger
repo = Repo(os.getcwd()) repo = Repo(os.getcwd())
sentry_sdk_git_hash = rev_parse(repo, "HEAD").hexsha sentry_sdk_git_hash = rev_parse(repo, "HEAD").hexsha
sentry_sdk.init( sentry_sdk.init(
config.error_sentry_dsn, config.error.sentry_dsn,
traces_sample_rate=1.0, traces_sample_rate=1.0,
release=sentry_sdk_git_hash, release=sentry_sdk_git_hash,
environment="production", environment="production",
@ -31,7 +31,7 @@ sentry_sdk.init(
class Sentry: class Sentry:
@staticmethod @staticmethod
def report_error(update: object, exc_info): def report_error(update: object, exc_info):
if not config.error_sentry_dsn: if not config.error.sentry_dsn:
return return
logger.info("正在上传日记到 sentry") logger.info("正在上传日记到 sentry")
message: str = "" message: str = ""
@ -45,8 +45,6 @@ class Sentry:
if update.effective_message: if update.effective_message:
if update.effective_message.text: if update.effective_message.text:
message = update.effective_message.text message = update.effective_message.text
sentry_sdk.set_context( sentry_sdk.set_context("Target", {"ChatID": str(chat_id), "UserID": str(user_id), "Msg": message})
"Target", {"ChatID": str(chat_id), "UserID": str(user_id), "Msg": message}
)
sentry_sdk.capture_exception(exc_info) sentry_sdk.capture_exception(exc_info)
logger.success("上传日记到 sentry 成功") logger.success("上传日记到 sentry 成功")

View File

@ -4,18 +4,18 @@ import time
import traceback import traceback
import aiofiles import aiofiles
from telegram import Update, ReplyKeyboardRemove from telegram import ReplyKeyboardRemove, Update
from telegram.constants import ParseMode from telegram.constants import ParseMode
from telegram.error import BadRequest, Forbidden from telegram.error import BadRequest, Forbidden
from telegram.ext import CallbackContext from telegram.ext import CallbackContext
from core.bot import bot from core.bot import bot
from core.plugin import error_handler, Plugin from core.plugin import Plugin, error_handler
from modules.error.pb import PbClient from modules.error.pb import PbClient
from modules.error.sentry import Sentry from modules.error.sentry import Sentry
from utils.log import logger from utils.log import logger
notice_chat_id = bot.config.error_notification_chat_id notice_chat_id = bot.config.error.notification_chat_id
current_dir = os.getcwd() current_dir = os.getcwd()
logs_dir = os.path.join(current_dir, "logs") logs_dir = os.path.join(current_dir, "logs")
if not os.path.exists(logs_dir): if not os.path.exists(logs_dir):

View File

@ -5,6 +5,7 @@ from httpx import URL
__all__ = [ __all__ = [
"PROJECT_ROOT", "PROJECT_ROOT",
"CORE_DIR",
"PLUGIN_DIR", "PLUGIN_DIR",
"RESOURCE_DIR", "RESOURCE_DIR",
"NOT_SET", "NOT_SET",

View File

@ -4,7 +4,10 @@ from typing import TYPE_CHECKING
from core.config import config from core.config import config
from utils.log._config import LoggerConfig from utils.log._config import LoggerConfig
from utils.log._logger import LogFilter, Logger from utils.log._logger import (
LogFilter,
Logger,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from logging import LogRecord from logging import LogRecord
@ -21,7 +24,7 @@ logger = Logger(
keywords=config.logger.render_keywords, keywords=config.logger.render_keywords,
traceback_locals_max_depth=config.logger.locals_max_depth, traceback_locals_max_depth=config.logger.locals_max_depth,
traceback_locals_max_length=config.logger.locals_max_length, traceback_locals_max_length=config.logger.locals_max_length,
traceback_locals_max_string=config.logger_locals_max_string, traceback_locals_max_string=config.logger.locals_max_string,
) )
) )

View File

@ -1,7 +1,13 @@
import imghdr import imghdr
import os import os
from enum import Enum from enum import Enum
from typing import Union, Optional from typing import (
Optional,
Union,
)
import ujson as json
from pydantic import BaseSettings
from utils.baseobject import BaseObject from utils.baseobject import BaseObject
@ -96,3 +102,14 @@ class ModuleInfo:
def __str__(self): def __str__(self):
return self.module_name return self.module_name
class Settings(BaseSettings):
def __new__(cls, *args, **kwargs):
cls.update_forward_refs()
return super(Settings, cls).__new__(cls)
class Config(BaseSettings.Config):
case_sensitive = False
json_loads = json.loads
json_dumps = json.dumps