From 112b2e92d8492df17dbae23024fa805e3510a56e Mon Sep 17 00:00:00 2001 From: xtaodada Date: Mon, 18 Nov 2024 21:04:23 +0800 Subject: [PATCH] :sparkles: Support mult user task service --- services/task/models.py | 1 + services/task/repositories.py | 9 +++++++-- services/task/services.py | 7 ++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/services/task/models.py b/services/task/models.py index cb0c4c0..06b8cd9 100644 --- a/services/task/models.py +++ b/services/task/models.py @@ -33,6 +33,7 @@ class Task(SQLModel, table=True): __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") id: Optional[int] = Field(default=None, sa_column=Column(Integer(), primary_key=True, autoincrement=True)) user_id: int = Field(sa_column=Column(BigInteger(), primary_key=True, index=True)) + player_id: int = Field(sa_column=Column(BigInteger(), primary_key=True, index=True)) chat_id: Optional[int] = Field(default=None, sa_column=Column(BigInteger())) time_created: Optional[datetime] = Field( sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 diff --git a/services/task/repositories.py b/services/task/repositories.py index 6fcea59..3791ece 100644 --- a/services/task/repositories.py +++ b/services/task/repositories.py @@ -31,9 +31,14 @@ class TaskRepository(BaseService.Component): await session.refresh(task) return task - async def get_by_user_id(self, user_id: int, task_type: TaskTypeEnum) -> Optional[Task]: + async def get_by_user_id(self, user_id: int, player_id: int, task_type: TaskTypeEnum) -> Optional[Task]: async with AsyncSession(self.engine) as session: - statement = select(Task).where(Task.user_id == user_id).where(Task.type == task_type) + statement = ( + select(Task) + .where(Task.user_id == user_id) + .where(Task.player_id == player_id) + .where(Task.type == task_type) + ) results = await session.exec(statement) return results.first() diff --git a/services/task/services.py b/services/task/services.py index 47b5048..f41c318 100644 --- a/services/task/services.py +++ b/services/task/services.py @@ -32,8 +32,8 @@ class TaskServices: task.time_updated = datetime.datetime.now() return await self._repository.update(task) - async def get_by_user_id(self, user_id: int): - return await self._repository.get_by_user_id(user_id, self.TASK_TYPE) + async def get_by_user_id(self, user_id: int, player_id: int): + return await self._repository.get_by_user_id(user_id, player_id, self.TASK_TYPE) async def get_all(self): return await self._repository.get_all(self.TASK_TYPE) @@ -41,9 +41,10 @@ class TaskServices: async def get_all_by_user_id(self, user_id: int): return await self._repository.get_all_by_user_id(user_id) - def create(self, user_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): + def create(self, user_id: int, player_id: int, chat_id: int, status: int, data: Optional[Dict[str, Any]] = None): return Task( user_id=user_id, + player_id=player_id, chat_id=chat_id, time_created=datetime.datetime.now(), status=status,