🎨 Improve config

This commit is contained in:
Karako 2022-10-28 15:11:14 +08:00 committed by GitHub
parent be5f4c51d5
commit f0b287dcfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 150 additions and 156 deletions

View File

@ -1,20 +1,24 @@
from itertools import chain
import os
import asyncio
import os
from importlib import import_module
from itertools import chain
from logging.config import fileConfig
from typing import Iterator
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlalchemy import (
engine_from_config,
pool,
)
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import SQLModel
from alembic import context
from utils.const import CORE_DIR, PLUGIN_DIR, PROJECT_ROOT
from utils.const import (
CORE_DIR,
PLUGIN_DIR,
PROJECT_ROOT,
)
from utils.log import logger
# this is the Alembic Config object, which provides

View File

@ -3,7 +3,10 @@ import asyncio
import uvicorn
from fastapi import FastAPI
from core.config import BotConfig, config as botConfig
from core.config import (
BotConfig,
config as botConfig,
)
from core.service import Service
__all__ = ["webapp", "WebServer"]
@ -28,7 +31,7 @@ class WebServer(Service):
@classmethod
def from_config(cls, config: BotConfig) -> Service:
return cls(debug=config.debug, **config.webserver.dict())
return cls(debug=config.debug, host=config.webserver.host, port=config.webserver.port)
def __init__(self, debug: bool, host: str, port: int):
self.debug = debug

View File

@ -7,10 +7,15 @@ from typing import (
)
import dotenv
import ujson as json
from pydantic import BaseModel, BaseSettings, validator
from pydantic import (
AnyUrl,
BaseModel,
Field,
validator,
)
from utils.const import PROJECT_ROOT
from utils.models.base import Settings
__all__ = ["BotConfig", "config", "JoinGroups"]
@ -23,116 +28,6 @@ class JoinGroups(str, Enum):
ALLOW_ALL = "ALLOW_ALL"
class BotConfig(BaseSettings):
debug: bool = False
db_host: str = ""
db_port: int = 0
db_username: str = ""
db_password: str = ""
db_database: str = ""
redis_host: str = ""
redis_port: int = 0
redis_db: int = 0
bot_token: str = ""
error_notification_chat_id: Optional[str] = None
api_id: Optional[int] = None
api_hash: Optional[str] = None
channels: List["ConfigChannel"] = []
admins: List["ConfigUser"] = []
verify_groups: List[Union[int, str]] = []
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
logger_name: str = "TGPaimon"
logger_width: int = 180
logger_log_path: str = "./logs"
logger_time_format: str = "[%Y-%m-%d %X]"
logger_traceback_max_frames: int = 20
logger_render_keywords: List[str] = ["BOT"]
logger_locals_max_depth: Optional[int] = 0
logger_locals_max_length: int = 10
logger_locals_max_string: int = 80
logger_filtered_names: List[str] = ["uvicorn"]
timeout: int = 10
read_timeout: float = 2
write_timeout: Optional[float] = None
connect_timeout: Optional[float] = None
pool_timeout: Optional[float] = None
genshin_ttl: Optional[int] = None
enka_network_api_agent: str = ""
pass_challenge_api: str = ""
pass_challenge_app_key: str = ""
web_url: str = "http://localhost:8080/"
web_host: str = "localhost"
web_port: int = 8080
error_pb_url: str = ""
error_pb_sunset: int = 43200
error_pb_max_lines: int = 1000
error_sentry_dsn: str = ""
class Config:
case_sensitive = False
json_loads = json.loads
json_dumps = json.dumps
@property
def mysql(self) -> "MySqlConfig":
return MySqlConfig(
host=self.db_host,
port=self.db_port,
username=self.db_username,
password=self.db_password,
database=self.db_database,
)
@property
def redis(self) -> "RedisConfig":
return RedisConfig(
host=self.redis_host,
port=self.redis_port,
database=self.redis_db,
)
@property
def logger(self) -> "LoggerConfig":
return LoggerConfig(
name=self.logger_name,
width=self.logger_width,
traceback_max_frames=self.logger_traceback_max_frames,
path=PROJECT_ROOT.joinpath(self.logger_log_path).resolve(),
time_format=self.logger_time_format,
render_keywords=self.logger_render_keywords,
locals_max_length=self.logger_locals_max_length,
locals_max_string=self.logger_locals_max_string,
locals_max_depth=self.logger_locals_max_depth,
filtered_names=self.logger_filtered_names,
)
@property
def mtproto(self) -> "MTProtoConfig":
return MTProtoConfig(
api_id=self.api_id,
api_hash=self.api_hash,
)
@property
def webserver(self) -> "WebServerConfig":
return WebServerConfig(
host=self.web_host,
port=self.web_port,
)
class ConfigChannel(BaseModel):
name: str
chat_id: int
@ -143,21 +38,27 @@ class ConfigUser(BaseModel):
user_id: int
class MySqlConfig(BaseModel):
class MySqlConfig(Settings):
host: str = "127.0.0.1"
port: int = 3306
username: str
password: str
database: str
class Config(Settings.Config):
env_prefix = "db_"
class RedisConfig(BaseModel):
class RedisConfig(Settings):
host: str = "127.0.0.1"
port: int
database: int = 0
port: int = 6379
database: int = Field(env="redis_db")
class Config(Settings.Config):
env_prefix = "redis_"
class LoggerConfig(BaseModel):
class LoggerConfig(Settings):
name: str = "TGPaimon"
width: int = 180
time_format: str = "[%Y-%m-%d %X]"
@ -171,19 +72,67 @@ class LoggerConfig(BaseModel):
@validator("locals_max_depth", pre=True, check_fields=False)
def locals_max_depth_validator(cls, value) -> Optional[int]: # pylint: disable=R0201
if value <= 0:
if int(value) <= 0:
return None
return value
class MTProtoConfig(BaseModel):
api_id: Optional[int]
api_hash: Optional[str]
class Config(Settings.Config):
env_prefix = "logger_"
class WebServerConfig(BaseModel):
host: Optional[str]
port: Optional[int]
class MTProtoConfig(Settings):
api_id: Optional[int] = None
api_hash: Optional[str] = None
class WebServerConfig(Settings):
url: AnyUrl = "http://localhost:8080"
host: str = "localhost"
port: int = 8080
class Config(Settings.Config):
env_prefix = "web_"
class ErrorConfig(Settings):
pb_url: str = ""
pb_sunset: int = 43200
pb_max_lines: int = 1000
sentry_dsn: str = ""
notification_chat_id: Optional[str] = None
class Config(Settings.Config):
env_prefix = "error_"
class BotConfig(Settings):
debug: bool = False
bot_token: str = ""
channels: List["ConfigChannel"] = []
admins: List["ConfigUser"] = []
verify_groups: List[Union[int, str]] = []
join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW
timeout: int = 10
read_timeout: float = 2
write_timeout: Optional[float] = None
connect_timeout: Optional[float] = None
pool_timeout: Optional[float] = None
genshin_ttl: Optional[int] = None
enka_network_api_agent: str = ""
pass_challenge_api: str = ""
pass_challenge_app_key: str = ""
mysql: MySqlConfig = MySqlConfig()
logger: LoggerConfig = LoggerConfig()
webserver: WebServerConfig = WebServerConfig()
redis: RedisConfig = RedisConfig()
mtproto: MTProtoConfig = MTProtoConfig()
error: ErrorConfig = ErrorConfig()
BotConfig.update_forward_refs()

View File

@ -1,20 +1,37 @@
import time
from typing import Optional
from urllib.parse import urlencode, urljoin, urlsplit
from urllib.parse import (
urlencode,
urljoin,
urlsplit,
)
from uuid import uuid4
from fastapi import HTTPException
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.responses import (
FileResponse,
HTMLResponse,
)
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, Template
from jinja2 import (
Environment,
FileSystemLoader,
Template,
)
from playwright.async_api import ViewportSize
from core.base.aiobrowser import AioBrowser
from core.base.webserver import webapp
from core.bot import bot
from core.template.cache import HtmlToFileIdCache, TemplatePreviewCache
from core.template.cache import (
HtmlToFileIdCache,
TemplatePreviewCache,
)
from core.template.error import QuerySelectorNotFound
from core.template.models import FileType, RenderResult
from core.template.models import (
FileType,
RenderResult,
)
from utils.const import PROJECT_ROOT
from utils.log import logger
@ -149,7 +166,7 @@ class TemplatePreviewer:
async def get_preview_url(self, template: str, data: dict):
"""获取预览 URL"""
components = urlsplit(bot.config.web_url)
components = urlsplit(bot.config.webserver.url)
path = urljoin("/preview/", template)
query = {}
@ -187,4 +204,6 @@ class TemplatePreviewer:
# 其他静态资源
for name in ["cache", "resources"]:
directory = PROJECT_ROOT / name
directory.mkdir(exist_ok=True)
webapp.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name)

View File

@ -7,10 +7,10 @@ from utils.log import logger
class PbClient:
def __init__(self):
self.client = httpx.AsyncClient()
self.PB_API = config.error_pb_url
self.sunset: int = config.error_pb_sunset # 自动销毁时间 单位为秒
self.PB_API = config.error.pb_url
self.sunset: int = config.error.pb_sunset # 自动销毁时间 单位为秒
self.private: bool = True
self.max_lines: int = config.error_pb_max_lines
self.max_lines: int = config.error.pb_max_lines
async def create_pb(self, content: str) -> str:
if not self.PB_API:

View File

@ -15,7 +15,7 @@ from utils.log import logger
repo = Repo(os.getcwd())
sentry_sdk_git_hash = rev_parse(repo, "HEAD").hexsha
sentry_sdk.init(
config.error_sentry_dsn,
config.error.sentry_dsn,
traces_sample_rate=1.0,
release=sentry_sdk_git_hash,
environment="production",
@ -31,7 +31,7 @@ sentry_sdk.init(
class Sentry:
@staticmethod
def report_error(update: object, exc_info):
if not config.error_sentry_dsn:
if not config.error.sentry_dsn:
return
logger.info("正在上传日记到 sentry")
message: str = ""
@ -45,8 +45,6 @@ class Sentry:
if update.effective_message:
if update.effective_message.text:
message = update.effective_message.text
sentry_sdk.set_context(
"Target", {"ChatID": str(chat_id), "UserID": str(user_id), "Msg": message}
)
sentry_sdk.set_context("Target", {"ChatID": str(chat_id), "UserID": str(user_id), "Msg": message})
sentry_sdk.capture_exception(exc_info)
logger.success("上传日记到 sentry 成功")

View File

@ -4,18 +4,18 @@ import time
import traceback
import aiofiles
from telegram import Update, ReplyKeyboardRemove
from telegram import ReplyKeyboardRemove, Update
from telegram.constants import ParseMode
from telegram.error import BadRequest, Forbidden
from telegram.ext import CallbackContext
from core.bot import bot
from core.plugin import error_handler, Plugin
from core.plugin import Plugin, error_handler
from modules.error.pb import PbClient
from modules.error.sentry import Sentry
from utils.log import logger
notice_chat_id = bot.config.error_notification_chat_id
notice_chat_id = bot.config.error.notification_chat_id
current_dir = os.getcwd()
logs_dir = os.path.join(current_dir, "logs")
if not os.path.exists(logs_dir):

View File

@ -5,6 +5,7 @@ from httpx import URL
__all__ = [
"PROJECT_ROOT",
"CORE_DIR",
"PLUGIN_DIR",
"RESOURCE_DIR",
"NOT_SET",

View File

@ -4,7 +4,10 @@ from typing import TYPE_CHECKING
from core.config import config
from utils.log._config import LoggerConfig
from utils.log._logger import LogFilter, Logger
from utils.log._logger import (
LogFilter,
Logger,
)
if TYPE_CHECKING:
from logging import LogRecord
@ -21,7 +24,7 @@ logger = Logger(
keywords=config.logger.render_keywords,
traceback_locals_max_depth=config.logger.locals_max_depth,
traceback_locals_max_length=config.logger.locals_max_length,
traceback_locals_max_string=config.logger_locals_max_string,
traceback_locals_max_string=config.logger.locals_max_string,
)
)

View File

@ -1,7 +1,13 @@
import imghdr
import os
from enum import Enum
from typing import Union, Optional
from typing import (
Optional,
Union,
)
import ujson as json
from pydantic import BaseSettings
from utils.baseobject import BaseObject
@ -96,3 +102,14 @@ class ModuleInfo:
def __str__(self):
return self.module_name
class Settings(BaseSettings):
def __new__(cls, *args, **kwargs):
cls.update_forward_refs()
return super(Settings, cls).__new__(cls)
class Config(BaseSettings.Config):
case_sensitive = False
json_loads = json.loads
json_dumps = json.dumps