Add channel alias service

This commit is contained in:
omg-xtao 2024-03-10 19:31:21 +08:00 committed by GitHub
parent b8260c2964
commit 968b3fd52d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 222 additions and 0 deletions

View File

@ -3,6 +3,9 @@ from .delete_message import DeleteMessage
from .download_resource import DownloadResource
from .get_args import GetArgs
from .get_chat import GetChat
from .get_real_user_id import GetRealUserId
from .get_real_user_name import GetRealUserName
from .log_user import LogUser
from .migrate_data import MigrateData
@ -12,6 +15,9 @@ class PluginFuncMethods(
DownloadResource,
GetArgs,
GetChat,
GetRealUserId,
GetRealUserName,
LogUser,
MigrateData,
):
"""插件方法"""

View File

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
from gram_core.services.channels.services import ChannelAliasService
if TYPE_CHECKING:
from . import PluginFuncMethods
from telegram import Update
class GetRealUserId:
async def get_real_user_id(self: "PluginFuncMethods", update: "Update") -> int:
message = update.effective_message
if message:
channel = message.sender_chat
if channel:
channel_alias_service: ChannelAliasService = self.application.managers.services_map.get(
ChannelAliasService, None
)
if channel_alias_service:
if uid := await channel_alias_service.get_uid_by_chat_id(channel.id, is_valid=True):
return uid
user = update.effective_user
return user.id

View File

@ -0,0 +1,18 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from telegram import Update
class GetRealUserName:
@staticmethod
def get_real_user_name(
update: "Update",
) -> str:
user = update.effective_user
message = update.effective_message
if message:
channel = message.sender_chat
if channel:
return channel.title
return user.first_name

View File

@ -0,0 +1,24 @@
from typing import TYPE_CHECKING, Callable, Union
if TYPE_CHECKING:
from telegram import Update
class LogUser:
@staticmethod
def log_user(update: Union["Update", int], func: Callable, msg: str, *args, **kwargs) -> None:
start_msg = "用户 %s[%s] "
if isinstance(update, int):
args2 = ("", update) + args
if update < 0:
start_msg = "频道 %s[%s] "
else:
user = update.effective_user
args2 = (user.full_name, user.id) + args
message = update.effective_message
if message:
channel = message.sender_chat
if channel:
start_msg = "频道 %s[%s] "
args2 = (channel.title, channel.id) + args
func(start_msg + str(msg), *args2, **kwargs)

View File

View File

@ -0,0 +1,29 @@
from typing import Optional
from gram_core.base_service import BaseService
from gram_core.dependence.redisdb import RedisDB
__all__ = ("ChannelAliasCache",)
class ChannelAliasCache(BaseService.Component):
def __init__(self, redis: RedisDB):
self.client = redis.client
self.qname = "channels:alias"
self.ttl = 1 * 60 * 60
def cache_key(self, key: int) -> str:
return f"{self.qname}:{key}"
async def get_data(self, channel_id: int) -> Optional[int]:
data = await self.client.get(self.cache_key(channel_id))
if data:
return int(data.decode())
return None
async def set_data(self, channel_id: int, user_id: int):
ck = self.cache_key(channel_id)
await self.client.set(ck, user_id, ex=self.ttl)
async def delete(self, channel_id: int):
await self.client.delete(self.cache_key(channel_id))

View File

@ -0,0 +1,23 @@
from datetime import datetime
from typing import Optional
from sqlmodel import SQLModel, Field, DateTime, Column, BigInteger, Integer
__all__ = (
"ChannelAlias",
"ChannelAliasDataBase",
)
class ChannelAlias(SQLModel):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(default=None, sa_column=Column(Integer(), primary_key=True, autoincrement=True))
chat_id: int = Field(sa_column=Column(BigInteger(), unique=True))
user_id: int = Field(sa_column=Column(BigInteger()))
is_valid: bool = Field(default=True)
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
updated_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
class ChannelAliasDataBase(ChannelAlias, table=True):
__tablename__ = "channel_alias"

View File

@ -0,0 +1,50 @@
from typing import Optional, List
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from gram_core.base_service import BaseService
from gram_core.dependence.database import Database
from gram_core.services.channels.models import ChannelAliasDataBase as ChannelAlias
__all__ = ("ChannelAliasRepository",)
class ChannelAliasRepository(BaseService.Component):
def __init__(self, database: Database):
self.engine = database.engine
async def get_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[ChannelAlias]:
async with AsyncSession(self.engine) as session:
statement = select(ChannelAlias).where(ChannelAlias.chat_id == chat_id)
if is_valid is not None:
statement = statement.where(ChannelAlias.is_valid == is_valid)
results = await session.exec(statement)
return results.first()
async def add(self, channel_alias: ChannelAlias) -> ChannelAlias:
async with AsyncSession(self.engine) as session:
session.add(channel_alias)
await session.commit()
await session.refresh(channel_alias)
return channel_alias
async def update(self, channel_alias: ChannelAlias) -> ChannelAlias:
async with AsyncSession(self.engine) as session:
session.add(channel_alias)
await session.commit()
await session.refresh(channel_alias)
return channel_alias
async def remove(self, channel_alias: ChannelAlias):
async with AsyncSession(self.engine) as session:
await session.delete(channel_alias)
await session.commit()
async def get_all(self, is_valid: Optional[bool] = None) -> List[ChannelAlias]:
async with AsyncSession(self.engine) as session:
statement = select(ChannelAlias)
if is_valid is not None:
statement = statement.where(ChannelAlias.is_valid == is_valid)
results = await session.exec(statement)
return results.all()

View File

@ -0,0 +1,49 @@
from typing import Optional
from gram_core.base_service import BaseService
from gram_core.services.channels.cache import ChannelAliasCache
from gram_core.services.channels.models import ChannelAliasDataBase as ChannelAlias
from gram_core.services.channels.repositories import ChannelAliasRepository
__all__ = ("ChannelAliasService",)
class ChannelAliasService(BaseService):
def __init__(self, channel_alias_repository: ChannelAliasRepository, cache: ChannelAliasCache):
self.channel_alias_repository = channel_alias_repository
self._cache = cache
async def initialize(self):
channels = await self.channel_alias_repository.get_all(is_valid=True)
for channel in channels:
if channel.chat_id and channel.user_id:
await self._cache.set_data(channel.chat_id, channel.user_id)
async def get_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[ChannelAlias]:
return await self.channel_alias_repository.get_by_chat_id(chat_id, is_valid)
async def get_uid_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[int]:
if uid := await self._cache.get_data(chat_id):
return uid
if channel := await self.get_by_chat_id(chat_id, is_valid):
await self._cache.set_data(channel.chat_id, channel.user_id)
return channel.user_id
await self._cache.set_data(chat_id, 0)
return None
async def add_channel_alias(self, channel_alias: ChannelAlias) -> ChannelAlias:
channel_alias = await self.channel_alias_repository.add(channel_alias)
await self._cache.set_data(channel_alias.chat_id, channel_alias.user_id)
return channel_alias
async def update_channel_alias(self, channel_alias: ChannelAlias) -> ChannelAlias:
channel_alias = await self.channel_alias_repository.update(channel_alias)
if channel_alias.is_valid:
await self._cache.set_data(channel_alias.chat_id, channel_alias.user_id)
else:
await self._cache.delete(channel_alias.chat_id)
return channel_alias
async def remove_channel_alias(self, channel_alias: ChannelAlias):
await self.channel_alias_repository.remove(channel_alias)
await self._cache.delete(channel_alias.chat_id)