GramCore/services/groups/repositories.py
2024-01-16 22:38:33 +08:00

65 lines
2.4 KiB
Python

from datetime import datetime, timedelta
from typing import Optional, List
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from gram_core.base_service import BaseService
from gram_core.dependence.database import Database
from gram_core.services.groups.models import GroupDataBase as Group
__all__ = ("GroupRepository",)


class GroupRepository(BaseService.Component):
    """Data-access layer for group records.

    Every method opens its own short-lived ``AsyncSession`` on the engine
    taken from the injected :class:`Database`, so calls do not share a
    session with each other.
    """

    def __init__(self, database: Database):
        # Only the engine is kept; a fresh session is created per call.
        self.engine = database.engine

    async def get_by_chat_id(self, chat_id: int) -> Optional[Group]:
        """Return the group whose ``chat_id`` equals *chat_id*, or ``None``.

        Args:
            chat_id: Chat identifier to look up.

        Returns:
            The first matching group row, or ``None`` when no row matches.
        """
        async with AsyncSession(self.engine) as session:
            statement = select(Group).where(Group.chat_id == chat_id)
            results = await session.exec(statement)
            return results.first()

    async def add(self, group: Group):
        """Insert *group* into the database and commit.

        Args:
            group: The group row to persist.
        """
        async with AsyncSession(self.engine) as session:
            session.add(group)
            await session.commit()

    async def update(self, group: Group) -> Group:
        """Persist changes to *group*, commit, and return it refreshed.

        Args:
            group: The (possibly detached) group row carrying new values.

        Returns:
            The same object, refreshed from the database after the commit.
        """
        async with AsyncSession(self.engine) as session:
            session.add(group)
            await session.commit()
            await session.refresh(group)
            return group

    async def remove(self, group: Group):
        """Delete *group* from the database and commit.

        Args:
            group: The group row to delete.
        """
        async with AsyncSession(self.engine) as session:
            await session.delete(group)
            await session.commit()

    async def get_all(self, is_banned: Optional[bool] = None, is_left: Optional[bool] = None) -> List[Group]:
        """Return all groups, optionally filtered by status flags.

        Args:
            is_banned: When not ``None``, keep only groups whose ``is_banned``
                equals this value.
            is_left: When not ``None``, keep only groups whose ``is_left``
                equals this value.

        Returns:
            All matching group rows.
        """
        async with AsyncSession(self.engine) as session:
            statement = select(Group)
            if is_banned is not None:
                statement = statement.where(Group.is_banned == is_banned)
            if is_left is not None:
                statement = statement.where(Group.is_left == is_left)
            results = await session.exec(statement)
            return results.all()

    async def get_no_need_update(self, limit: int = 10) -> List[Group]:
        """Return groups that are still active and were updated within a day.

        A group qualifies when it is neither left nor banned and its
        ``updated_at`` is later than one day before now.

        Args:
            limit: Maximum number of rows to return. A falsy value (``0`` or
                ``None``) disables the limit.

        Returns:
            The matching group rows.
        """
        # Compute the cutoff in Python and compare the bare column to it.
        # This is equivalent to ``updated_at + 1 day > now`` but avoids
        # datetime/interval arithmetic inside SQL, whose rendering is
        # dialect-dependent (e.g. on MySQL), and keeps the predicate usable
        # with an index on ``updated_at``.
        threshold = datetime.now() - timedelta(days=1)
        async with AsyncSession(self.engine) as session:
            statement = (
                select(Group)
                .where(Group.is_left == False)  # noqa: E712  (SQL comparison, not Python truthiness)
                .where(Group.is_banned == False)  # noqa: E712
                .where(Group.updated_at > threshold)
            )
            if limit:
                statement = statement.limit(limit)
            results = await session.exec(statement)
            return results.all()