Support mult user task service

This commit is contained in:
xtaodada 2024-11-18 21:04:23 +08:00
parent 7129322f6d
commit 112b2e92d8
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
3 changed files with 12 additions and 5 deletions

View File

@ -33,6 +33,7 @@ class Task(SQLModel, table=True):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
id: Optional[int] = Field(default=None, sa_column=Column(Integer(), primary_key=True, autoincrement=True)) id: Optional[int] = Field(default=None, sa_column=Column(Integer(), primary_key=True, autoincrement=True))
user_id: int = Field(sa_column=Column(BigInteger(), primary_key=True, index=True)) user_id: int = Field(sa_column=Column(BigInteger(), primary_key=True, index=True))
player_id: int = Field(sa_column=Column(BigInteger(), primary_key=True, index=True))
chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger())) chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger()))
time_created: Optional[datetime] = Field( time_created: Optional[datetime] = Field(
sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102

View File

@ -31,9 +31,14 @@ class TaskRepository(BaseService.Component):
await session.refresh(task) await session.refresh(task)
return task return task
async def get_by_user_id(self, user_id: int, task_type: TaskTypeEnum) -> Optional[Task]: async def get_by_user_id(self, user_id: int, player_id: int, task_type: TaskTypeEnum) -> Optional[Task]:
async with AsyncSession(self.engine) as session: async with AsyncSession(self.engine) as session:
statement = select(Task).where(Task.user_id == user_id).where(Task.type == task_type) statement = (
select(Task)
.where(Task.user_id == user_id)
.where(Task.player_id == player_id)
.where(Task.type == task_type)
)
results = await session.exec(statement) results = await session.exec(statement)
return results.first() return results.first()

View File

@ -32,8 +32,8 @@ class TaskServices:
task.time_updated = datetime.datetime.now() task.time_updated = datetime.datetime.now()
return await self._repository.update(task) return await self._repository.update(task)
async def get_by_user_id(self, user_id: int): async def get_by_user_id(self, user_id: int, player_id: int):
return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) return await self._repository.get_by_user_id(user_id, player_id, self.TASK_TYPE)
async def get_all(self): async def get_all(self):
return await self._repository.get_all(self.TASK_TYPE) return await self._repository.get_all(self.TASK_TYPE)
@ -41,9 +41,10 @@ class TaskServices:
async def get_all_by_user_id(self, user_id: int): async def get_all_by_user_id(self, user_id: int):
return await self._repository.get_all_by_user_id(user_id) return await self._repository.get_all_by_user_id(user_id)
def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): def create(self, user_id: int, player_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None):
return Task( return Task(
user_id=user_id, user_id=user_id,
player_id=player_id,
chat_id=chat_id, chat_id=chat_id,
time_created=datetime.datetime.now(), time_created=datetime.datetime.now(),
status=status, status=status,