mirror of
https://github.com/PaiGramTeam/GramCore.git
synced 2024-11-25 15:42:14 +00:00
65 lines
2.4 KiB
Python
65 lines
2.4 KiB
Python
from datetime import datetime, timedelta
|
|
from typing import Optional, List
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from gram_core.base_service import BaseService
|
|
from gram_core.dependence.database import Database
|
|
from gram_core.services.groups.models import GroupDataBase as Group
|
|
|
|
# Public API of this module: the only name exported by a star-import.
__all__ = ("GroupRepository",)
|
|
|
|
|
|
class GroupRepository(BaseService.Component):
    """Database access helpers for ``Group`` rows.

    Each method opens its own ``AsyncSession`` on the engine taken from the
    injected ``Database`` and leaves the session context before returning.
    """

    def __init__(self, database: Database) -> None:
        """Keep a reference to the engine exposed by *database*."""
        self.engine = database.engine

    async def get_by_chat_id(self, chat_id: int) -> Optional[Group]:
        """Return the first group whose ``chat_id`` equals *chat_id*.

        :param chat_id: Value to match against ``Group.chat_id``.
        :return: The first matching row, or ``None`` when there is none.
        """
        async with AsyncSession(self.engine) as session:
            statement = select(Group).where(Group.chat_id == chat_id)
            results = await session.exec(statement)
            return results.first()

    async def add(self, group: Group) -> None:
        """Add *group* to a new session and commit it."""
        async with AsyncSession(self.engine) as session:
            session.add(group)
            await session.commit()

    async def update(self, group: Group) -> Group:
        """Commit *group* and reload its state from the database.

        :param group: The instance to persist.
        :return: The same instance, refreshed after the commit.
        """
        async with AsyncSession(self.engine) as session:
            session.add(group)
            await session.commit()
            # Reload the attributes so the caller gets the stored values.
            await session.refresh(group)
            return group

    async def remove(self, group: Group) -> None:
        """Delete *group* and commit the deletion."""
        async with AsyncSession(self.engine) as session:
            await session.delete(group)
            await session.commit()

    async def get_all(self, is_banned: Optional[bool] = None, is_left: Optional[bool] = None) -> List[Group]:
        """Return all groups, optionally filtered by their flags.

        :param is_banned: When not ``None``, keep only rows whose
            ``is_banned`` equals this value.
        :param is_left: When not ``None``, keep only rows whose ``is_left``
            equals this value.
        :return: Every row matching the requested filters.
        """
        async with AsyncSession(self.engine) as session:
            statement = select(Group)
            # ``None`` means "do not filter on this flag at all".
            if is_banned is not None:
                statement = statement.where(Group.is_banned == is_banned)
            if is_left is not None:
                statement = statement.where(Group.is_left == is_left)
            results = await session.exec(statement)
            return results.all()

    async def get_no_need_update(self, limit: int = 10) -> List[Group]:
        """Return groups updated less than one day ago.

        Only rows with ``is_left`` and ``is_banned`` both false are
        considered. "Now" is taken from ``datetime.now()`` (naive local time).

        :param limit: Maximum number of rows to return; a falsy value
            (for example ``0``) disables the limit.
        :return: The matching groups.
        """
        # Compute the cutoff in Python and compare the bare column against it.
        # This is logically the same as ``updated_at + 1 day > now`` but keeps
        # datetime/interval arithmetic out of the SQL, where its rendering
        # depends on the backend and where wrapping the column in an
        # expression keeps an index on ``updated_at`` from being used.
        cutoff = datetime.now() - timedelta(days=1)
        async with AsyncSession(self.engine) as session:
            # Compared through locals rather than literal ``== False``.
            is_left = False
            is_banned = False
            statement = (
                select(Group)
                .where(Group.is_left == is_left)
                .where(Group.is_banned == is_banned)
                .where(Group.updated_at > cutoff)
            )
            if limit:
                statement = statement.limit(limit)
            results = await session.exec(statement)
            return results.all()
|