mirror of
https://github.com/PaiGramTeam/GramCore.git
synced 2024-11-24 07:10:37 +00:00
✨ Add channel alias service
This commit is contained in:
parent
b8260c2964
commit
968b3fd52d
@ -3,6 +3,9 @@ from .delete_message import DeleteMessage
|
|||||||
from .download_resource import DownloadResource
|
from .download_resource import DownloadResource
|
||||||
from .get_args import GetArgs
|
from .get_args import GetArgs
|
||||||
from .get_chat import GetChat
|
from .get_chat import GetChat
|
||||||
|
from .get_real_user_id import GetRealUserId
|
||||||
|
from .get_real_user_name import GetRealUserName
|
||||||
|
from .log_user import LogUser
|
||||||
from .migrate_data import MigrateData
|
from .migrate_data import MigrateData
|
||||||
|
|
||||||
|
|
||||||
@ -12,6 +15,9 @@ class PluginFuncMethods(
|
|||||||
DownloadResource,
|
DownloadResource,
|
||||||
GetArgs,
|
GetArgs,
|
||||||
GetChat,
|
GetChat,
|
||||||
|
GetRealUserId,
|
||||||
|
GetRealUserName,
|
||||||
|
LogUser,
|
||||||
MigrateData,
|
MigrateData,
|
||||||
):
|
):
|
||||||
"""插件方法"""
|
"""插件方法"""
|
||||||
|
23
plugin/methods/get_real_user_id.py
Normal file
23
plugin/methods/get_real_user_id.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from gram_core.services.channels.services import ChannelAliasService
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import PluginFuncMethods
|
||||||
|
from telegram import Update
|
||||||
|
|
||||||
|
|
||||||
|
class GetRealUserId:
|
||||||
|
async def get_real_user_id(self: "PluginFuncMethods", update: "Update") -> int:
|
||||||
|
message = update.effective_message
|
||||||
|
if message:
|
||||||
|
channel = message.sender_chat
|
||||||
|
if channel:
|
||||||
|
channel_alias_service: ChannelAliasService = self.application.managers.services_map.get(
|
||||||
|
ChannelAliasService, None
|
||||||
|
)
|
||||||
|
if channel_alias_service:
|
||||||
|
if uid := await channel_alias_service.get_uid_by_chat_id(channel.id, is_valid=True):
|
||||||
|
return uid
|
||||||
|
user = update.effective_user
|
||||||
|
return user.id
|
18
plugin/methods/get_real_user_name.py
Normal file
18
plugin/methods/get_real_user_name.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from telegram import Update
|
||||||
|
|
||||||
|
|
||||||
|
class GetRealUserName:
|
||||||
|
@staticmethod
|
||||||
|
def get_real_user_name(
|
||||||
|
update: "Update",
|
||||||
|
) -> str:
|
||||||
|
user = update.effective_user
|
||||||
|
message = update.effective_message
|
||||||
|
if message:
|
||||||
|
channel = message.sender_chat
|
||||||
|
if channel:
|
||||||
|
return channel.title
|
||||||
|
return user.first_name
|
24
plugin/methods/log_user.py
Normal file
24
plugin/methods/log_user.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from typing import TYPE_CHECKING, Callable, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from telegram import Update
|
||||||
|
|
||||||
|
|
||||||
|
class LogUser:
|
||||||
|
@staticmethod
|
||||||
|
def log_user(update: Union["Update", int], func: Callable, msg: str, *args, **kwargs) -> None:
|
||||||
|
start_msg = "用户 %s[%s] "
|
||||||
|
if isinstance(update, int):
|
||||||
|
args2 = ("", update) + args
|
||||||
|
if update < 0:
|
||||||
|
start_msg = "频道 %s[%s] "
|
||||||
|
else:
|
||||||
|
user = update.effective_user
|
||||||
|
args2 = (user.full_name, user.id) + args
|
||||||
|
message = update.effective_message
|
||||||
|
if message:
|
||||||
|
channel = message.sender_chat
|
||||||
|
if channel:
|
||||||
|
start_msg = "频道 %s[%s] "
|
||||||
|
args2 = (channel.title, channel.id) + args
|
||||||
|
func(start_msg + str(msg), *args2, **kwargs)
|
0
services/channels/__init__.py
Normal file
0
services/channels/__init__.py
Normal file
29
services/channels/cache.py
Normal file
29
services/channels/cache.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from gram_core.base_service import BaseService
|
||||||
|
from gram_core.dependence.redisdb import RedisDB
|
||||||
|
|
||||||
|
__all__ = ("ChannelAliasCache",)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAliasCache(BaseService.Component):
|
||||||
|
def __init__(self, redis: RedisDB):
|
||||||
|
self.client = redis.client
|
||||||
|
self.qname = "channels:alias"
|
||||||
|
self.ttl = 1 * 60 * 60
|
||||||
|
|
||||||
|
def cache_key(self, key: int) -> str:
|
||||||
|
return f"{self.qname}:{key}"
|
||||||
|
|
||||||
|
async def get_data(self, channel_id: int) -> Optional[int]:
|
||||||
|
data = await self.client.get(self.cache_key(channel_id))
|
||||||
|
if data:
|
||||||
|
return int(data.decode())
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_data(self, channel_id: int, user_id: int):
|
||||||
|
ck = self.cache_key(channel_id)
|
||||||
|
await self.client.set(ck, user_id, ex=self.ttl)
|
||||||
|
|
||||||
|
async def delete(self, channel_id: int):
|
||||||
|
await self.client.delete(self.cache_key(channel_id))
|
23
services/channels/models.py
Normal file
23
services/channels/models.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel, Field, DateTime, Column, BigInteger, Integer
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"ChannelAlias",
|
||||||
|
"ChannelAliasDataBase",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAlias(SQLModel):
|
||||||
|
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||||
|
id: Optional[int] = Field(default=None, sa_column=Column(Integer(), primary_key=True, autoincrement=True))
|
||||||
|
chat_id: int = Field(sa_column=Column(BigInteger(), unique=True))
|
||||||
|
user_id: int = Field(sa_column=Column(BigInteger()))
|
||||||
|
is_valid: bool = Field(default=True)
|
||||||
|
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
|
||||||
|
updated_at: datetime = Field(sa_column=Column(DateTime(timezone=True)))
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAliasDataBase(ChannelAlias, table=True):
|
||||||
|
__tablename__ = "channel_alias"
|
50
services/channels/repositories.py
Normal file
50
services/channels/repositories.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from gram_core.base_service import BaseService
|
||||||
|
from gram_core.dependence.database import Database
|
||||||
|
from gram_core.services.channels.models import ChannelAliasDataBase as ChannelAlias
|
||||||
|
|
||||||
|
__all__ = ("ChannelAliasRepository",)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAliasRepository(BaseService.Component):
|
||||||
|
def __init__(self, database: Database):
|
||||||
|
self.engine = database.engine
|
||||||
|
|
||||||
|
async def get_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[ChannelAlias]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(ChannelAlias).where(ChannelAlias.chat_id == chat_id)
|
||||||
|
if is_valid is not None:
|
||||||
|
statement = statement.where(ChannelAlias.is_valid == is_valid)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.first()
|
||||||
|
|
||||||
|
async def add(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(channel_alias)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(channel_alias)
|
||||||
|
return channel_alias
|
||||||
|
|
||||||
|
async def update(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(channel_alias)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(channel_alias)
|
||||||
|
return channel_alias
|
||||||
|
|
||||||
|
async def remove(self, channel_alias: ChannelAlias):
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
await session.delete(channel_alias)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def get_all(self, is_valid: Optional[bool] = None) -> List[ChannelAlias]:
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
statement = select(ChannelAlias)
|
||||||
|
if is_valid is not None:
|
||||||
|
statement = statement.where(ChannelAlias.is_valid == is_valid)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return results.all()
|
49
services/channels/services.py
Normal file
49
services/channels/services.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from gram_core.base_service import BaseService
|
||||||
|
from gram_core.services.channels.cache import ChannelAliasCache
|
||||||
|
from gram_core.services.channels.models import ChannelAliasDataBase as ChannelAlias
|
||||||
|
from gram_core.services.channels.repositories import ChannelAliasRepository
|
||||||
|
|
||||||
|
__all__ = ("ChannelAliasService",)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelAliasService(BaseService):
|
||||||
|
def __init__(self, channel_alias_repository: ChannelAliasRepository, cache: ChannelAliasCache):
|
||||||
|
self.channel_alias_repository = channel_alias_repository
|
||||||
|
self._cache = cache
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
channels = await self.channel_alias_repository.get_all(is_valid=True)
|
||||||
|
for channel in channels:
|
||||||
|
if channel.chat_id and channel.user_id:
|
||||||
|
await self._cache.set_data(channel.chat_id, channel.user_id)
|
||||||
|
|
||||||
|
async def get_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[ChannelAlias]:
|
||||||
|
return await self.channel_alias_repository.get_by_chat_id(chat_id, is_valid)
|
||||||
|
|
||||||
|
async def get_uid_by_chat_id(self, chat_id: int, is_valid: Optional[bool] = None) -> Optional[int]:
|
||||||
|
if uid := await self._cache.get_data(chat_id):
|
||||||
|
return uid
|
||||||
|
if channel := await self.get_by_chat_id(chat_id, is_valid):
|
||||||
|
await self._cache.set_data(channel.chat_id, channel.user_id)
|
||||||
|
return channel.user_id
|
||||||
|
await self._cache.set_data(chat_id, 0)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def add_channel_alias(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||||
|
channel_alias = await self.channel_alias_repository.add(channel_alias)
|
||||||
|
await self._cache.set_data(channel_alias.chat_id, channel_alias.user_id)
|
||||||
|
return channel_alias
|
||||||
|
|
||||||
|
async def update_channel_alias(self, channel_alias: ChannelAlias) -> ChannelAlias:
|
||||||
|
channel_alias = await self.channel_alias_repository.update(channel_alias)
|
||||||
|
if channel_alias.is_valid:
|
||||||
|
await self._cache.set_data(channel_alias.chat_id, channel_alias.user_id)
|
||||||
|
else:
|
||||||
|
await self._cache.delete(channel_alias.chat_id)
|
||||||
|
return channel_alias
|
||||||
|
|
||||||
|
async def remove_channel_alias(self, channel_alias: ChannelAlias):
|
||||||
|
await self.channel_alias_repository.remove(channel_alias)
|
||||||
|
await self._cache.delete(channel_alias.chat_id)
|
Loading…
Reference in New Issue
Block a user