mirror of
https://github.com/PaiGramTeam/GramCore.git
synced 2024-11-21 13:48:20 +00:00
✨ Add channel alias service
This commit is contained in:
parent
b8260c2964
commit
968b3fd52d
@ -3,6 +3,9 @@ from .delete_message import DeleteMessage
|
||||
from .download_resource import DownloadResource
|
||||
from .get_args import GetArgs
|
||||
from .get_chat import GetChat
|
||||
from .get_real_user_id import GetRealUserId
|
||||
from .get_real_user_name import GetRealUserName
|
||||
from .log_user import LogUser
|
||||
from .migrate_data import MigrateData
|
||||
|
||||
|
||||
@ -12,6 +15,9 @@ class PluginFuncMethods(
|
||||
DownloadResource,
|
||||
GetArgs,
|
||||
GetChat,
|
||||
GetRealUserId,
|
||||
GetRealUserName,
|
||||
LogUser,
|
||||
MigrateData,
|
||||
):
|
||||
"""插件方法"""
|
||||
|
23
plugin/methods/get_real_user_id.py
Normal file
23
plugin/methods/get_real_user_id.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from gram_core.services.channels.services import ChannelAliasService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PluginFuncMethods
|
||||
from telegram import Update
|
||||
|
||||
|
||||
class GetRealUserId:
|
||||
async def get_real_user_id(self: "PluginFuncMethods", update: "Update") -> int:
|
||||
message = update.effective_message
|
||||
if message:
|
||||
channel = message.sender_chat
|
||||
if channel:
|
||||
channel_alias_service: ChannelAliasService = self.application.managers.services_map.get(
|
||||
ChannelAliasService, None
|
||||
)
|
||||
if channel_alias_service:
|
||||
if uid := await channel_alias_service.get_uid_by_chat_id(channel.id, is_valid=True):
|
||||
return uid
|
||||
user = update.effective_user
|
||||
return user.id
|
18
plugin/methods/get_real_user_name.py
Normal file
18
plugin/methods/get_real_user_name.py
Normal file
@ -0,0 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram import Update
|
||||
|
||||
|
||||
class GetRealUserName:
|
||||
@staticmethod
|
||||
def get_real_user_name(
|
||||
update: "Update",
|
||||
) -> str:
|
||||
user = update.effective_user
|
||||
message = update.effective_message
|
||||
if message:
|
||||
channel = message.sender_chat
|
||||
if channel:
|
||||
return channel.title
|
||||
return user.first_name
|
24
plugin/methods/log_user.py
Normal file
24
plugin/methods/log_user.py
Normal file
@ -0,0 +1,24 @@
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram import Update
|
||||
|
||||
|
||||
class LogUser:
|
||||
@staticmethod
|
||||
def log_user(update: Union["Update", int], func: Callable, msg: str, *args, **kwargs) -> None:
|
||||
start_msg = "用户 %s[%s] "
|
||||
if isinstance(update, int):
|
||||
args2 = ("", update) + args
|
||||
if update < 0:
|
||||
start_msg = "频道 %s[%s] "
|
||||
else:
|
||||
user = update.effective_user
|
||||
args2 = (user.full_name, user.id) + args
|
||||
message = update.effective_message
|
||||
if message:
|
||||
channel = message.sender_chat
|
||||
if channel:
|
||||
start_msg = "频道 %s[%s] "
|
||||
args2 = (channel.title, channel.id) + args
|
||||
func(start_msg + str(msg), *args2, **kwargs)
|
0
services/channels/__init__.py
Normal file
0
services/channels/__init__.py
Normal file
29
services/channels/cache.py
Normal file
29
services/channels/cache.py
Normal file
@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
|
||||
from gram_core.base_service import BaseService
|
||||
from gram_core.dependence.redisdb import RedisDB
|
||||
|
||||
__all__ = ("ChannelAliasCache",)
|
||||
|
||||
|
||||
class ChannelAliasCache(BaseService.Component):
|
||||
def __init__(self, redis: RedisDB):
|
||||
self.client = redis.client
|
||||
self.qname = "channels:alias"
|
||||
self.ttl = 1 * 60 * 60
|
||||
|
||||
def cache_key(self, key: int) -> str:
|
||||
return f"{self.qname}:{key}"
|
||||
|
||||
async def get_data(self, channel_id: int) -> Optional[int]:
|
||||
data = await self.client.get(self.cache_key(channel_id))
|
||||
if data:
|
||||
return int(data.decode())
|
||||
return None
|
||||
|
||||
async def set_data(self, channel_id: int, user_id: int):
|
||||
ck = self.cache_key(channel_id)
|
||||
await self.client.set(ck, user_id, ex=self.ttl)
|
||||
|
||||
async def delete(self, channel_id: int):
|
||||
await self.client.delete(self.cache_key(channel_id))
|
23
services/channels/models.py
Normal file
23
services/channels/models.py
Normal file
@ -0,0 +1,23 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field, DateTime, Column, BigInteger, Integer
|
||||
|
||||
__all__ = (
|
||||
"ChannelAlias",
|
||||
"ChannelAliasDataBase",
|
||||
)
|
||||
|
||||
|
||||
class ChannelAlias(SQLModel):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
id: Optional[int] = Field(default=None, sa_column=Column(Integer(), primary_key=True, autoincrement=True))
|
||||
chat_id: int = Field(sa_column=Column(BigInteger(), unique=True))
|
||||
user_id: int = Field(sa_column=Column(BigInteger()))
|
||||
is_valid: bool = Field(default=True)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
|
||||
updated_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
|
||||
|
||||
|
||||
class ChannelAliasDataBase(ChannelAlias, table=True):
|
||||
__tablename__ = "channel_alias"
|
50
services/channels/repositories.py
Normal file
50
services/channels/repositories.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from gram_core.base_service import BaseService
|
||||
from gram_core.dependence.database import Database
|
||||
from gram_core.services.channels.models import ChannelAliasDataBase as ChannelAlias
|
||||
|
||||
__all__ = ("ChannelAliasRepository",)
|
||||
|
||||
|
||||
class ChannelAliasRepository(BaseService.Component):
|
||||
def __init__(self, database: Database):
|
||||
self.engine = database.engine
|
||||
|
||||
async def get_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[ChannelAlias]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(ChannelAlias).where(ChannelAlias.chat_id == chat_id)
|
||||
if is_valid is not None:
|
||||
statement = statement.where(ChannelAlias.is_valid == is_valid)
|
||||
results = await session.exec(statement)
|
||||
return results.first()
|
||||
|
||||
async def add(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(channel_alias)
|
||||
await session.commit()
|
||||
await session.refresh(channel_alias)
|
||||
return channel_alias
|
||||
|
||||
async def update(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(channel_alias)
|
||||
await session.commit()
|
||||
await session.refresh(channel_alias)
|
||||
return channel_alias
|
||||
|
||||
async def remove(self, channel_alias: ChannelAlias):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
await session.delete(channel_alias)
|
||||
await session.commit()
|
||||
|
||||
async def get_all(self, is_valid: Optional[bool] = None) -> List[ChannelAlias]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(ChannelAlias)
|
||||
if is_valid is not None:
|
||||
statement = statement.where(ChannelAlias.is_valid == is_valid)
|
||||
results = await session.exec(statement)
|
||||
return results.all()
|
49
services/channels/services.py
Normal file
49
services/channels/services.py
Normal file
@ -0,0 +1,49 @@
|
||||
from typing import Optional
|
||||
|
||||
from gram_core.base_service import BaseService
|
||||
from gram_core.services.channels.cache import ChannelAliasCache
|
||||
from gram_core.services.channels.models import ChannelAliasDataBase as ChannelAlias
|
||||
from gram_core.services.channels.repositories import ChannelAliasRepository
|
||||
|
||||
__all__ = ("ChannelAliasService",)
|
||||
|
||||
|
||||
class ChannelAliasService(BaseService):
|
||||
def __init__(self, channel_alias_repository: ChannelAliasRepository, cache: ChannelAliasCache):
|
||||
self.channel_alias_repository = channel_alias_repository
|
||||
self._cache = cache
|
||||
|
||||
async def initialize(self):
|
||||
channels = await self.channel_alias_repository.get_all(is_valid=True)
|
||||
for channel in channels:
|
||||
if channel.chat_id and channel.user_id:
|
||||
await self._cache.set_data(channel.chat_id, channel.user_id)
|
||||
|
||||
async def get_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[ChannelAlias]:
|
||||
return await self.channel_alias_repository.get_by_chat_id(chat_id, is_valid)
|
||||
|
||||
async def get_uid_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[int]:
|
||||
if uid := await self._cache.get_data(chat_id):
|
||||
return uid
|
||||
if channel := await self.get_by_chat_id(chat_id, is_valid):
|
||||
await self._cache.set_data(channel.chat_id, channel.user_id)
|
||||
return channel.user_id
|
||||
await self._cache.set_data(chat_id, 0)
|
||||
return None
|
||||
|
||||
async def add_channel_alias(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||
channel_alias = await self.channel_alias_repository.add(channel_alias)
|
||||
await self._cache.set_data(channel_alias.chat_id, channel_alias.user_id)
|
||||
return channel_alias
|
||||
|
||||
async def update_channel_alias(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||
channel_alias = await self.channel_alias_repository.update(channel_alias)
|
||||
if channel_alias.is_valid:
|
||||
await self._cache.set_data(channel_alias.chat_id, channel_alias.user_id)
|
||||
else:
|
||||
await self._cache.delete(channel_alias.chat_id)
|
||||
return channel_alias
|
||||
|
||||
async def remove_channel_alias(self, channel_alias: ChannelAlias):
|
||||
await self.channel_alias_repository.remove(channel_alias)
|
||||
await self._cache.delete(channel_alias.chat_id)
|
Loading…
Reference in New Issue
Block a user