From 7501be8d10cd998d3d0d394e72986313affefe93 Mon Sep 17 00:00:00 2001 From: xtaodada Date: Thu, 7 Nov 2024 17:10:19 +0800 Subject: [PATCH] feat: user avatar update --- pyproject.toml | 1 + src/errors.py | 3 ++ src/route/users_update.py | 41 +++++++++++++++++++++++++- src/services/users/schemas.py | 8 +++++ src/services/users/services.py | 9 ++++++ src/utils/__init__.py | 5 +--- src/utils/_path.py | 9 ++++++ src/utils/upload_file.py | 53 ++++++++++++++++++++++++++++++++++ uv.lock | 4 ++- 9 files changed, 127 insertions(+), 6 deletions(-) create mode 100644 src/errors.py create mode 100644 src/utils/_path.py create mode 100644 src/utils/upload_file.py diff --git a/pyproject.toml b/pyproject.toml index 623c44c..93bfb4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "aiofiles>=24.1.0", "alembic>=1.13.3", "asyncmy>=0.2.9", "black>=24.10.0", diff --git a/src/errors.py b/src/errors.py new file mode 100644 index 0000000..3a3fa0a --- /dev/null +++ b/src/errors.py @@ -0,0 +1,3 @@ +class ProjectBaseError(Exception): + def __init__(self, msg: str): + self.msg = msg diff --git a/src/route/users_update.py b/src/route/users_update.py index ed95001..c88a80c 100644 --- a/src/route/users_update.py +++ b/src/route/users_update.py @@ -1,14 +1,17 @@ +from fastapi import File, UploadFile from fastapi_amis_admin.crud import BaseApiOut from starlette import status from starlette.exceptions import HTTPException from starlette.requests import Request +from starlette.responses import FileResponse from src.plugin import handler from src.plugin.plugin import Plugin from src.services.users.models import UserModel -from src.services.users.schemas import UserUpdate +from src.services.users.schemas import UserUpdate, UserUpdateAvatar from src.services.users.services import UserServices, UserRoleServices +from src.utils.upload_file import get_avatar, save_avatar, check_avatar class UserUpdateRoutes(Plugin): @@ -19,6 +22,7 @@ class UserUpdateRoutes(Plugin): ): self.user_services = user_services self.user_role_services = user_role_services + self.avatar_path = "/user/avatar/" @handler.get("/me", student=True, out=True) async def get_me(self, request: Request): @@ -62,3 +66,38 @@ class UserUpdateRoutes(Plugin): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error Execute SQL:{e}", ) from e + + @handler.get("/avatar/{uid}/{file_path}", admin=False) + async def get_avatar(self, request: Request, uid: int, file_path: str): + # if request.user.id != uid: + # return BaseApiOut(status=500, msg="无权查看他人头像") + path = await get_avatar(uid, file_path) + if not path: + return BaseApiOut(status=500, msg="文件不存在") + return FileResponse(path) + + @handler.post("/update/avatar/upload", student=True, out=True) + async def update_upload_avatar( + self, request: Request, file: UploadFile = File(...) + ): + user: "UserModel" = request.user + path = await save_avatar(user.id, file) + real_path = self.avatar_path + str(user.id) + "/" + path + return BaseApiOut(code=0, msg="上传成功", data=real_path) + + @handler.post("/update/avatar/save", student=True, out=True) + async def update_save_avatar(self, request: Request, data: UserUpdateAvatar): + user: "UserModel" = request.user + avatar = data.avatar + if not avatar.startswith(self.avatar_path): + return BaseApiOut(status=500, msg="头像地址错误") + if not await check_avatar(user.id, avatar[len(self.avatar_path) :]): + return BaseApiOut(status=500, msg="头像不存在") + try: + user = await self.user_services.update_user_avatar(user.username, avatar) + return BaseApiOut(code=0, msg="更新成功", data=user) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error Execute SQL:{e}", + ) from e diff --git a/src/services/users/schemas.py b/src/services/users/schemas.py index bc1160d..67fe4f5 100644 --- a/src/services/users/schemas.py +++ b/src/services/users/schemas.py @@ -89,6 +89,14 @@ class UserUpdate( ) +class UserUpdateAvatar(BaseModel): + avatar: str = Field( + title=_("Avatar"), + max_length=255, + nullable=True, + ) + + # 默认保留的用户 class SystemUserEnum(str, Enum): ROOT = "root" diff --git a/src/services/users/services.py b/src/services/users/services.py index 3459255..af36edf 100644 --- a/src/services/users/services.py +++ b/src/services/users/services.py @@ -86,6 +86,15 @@ class UserServices(AsyncInitializingComponent): user.password = self.repo.AUTH.pwd_context.hash(password) return await self.repo.update_user(user) + async def update_user_avatar( + self, username: str, avatar: str + ) -> Optional[UserModel]: + user = await self.get_user(username=username) + if not user: + return None + user.avatar = avatar + return await self.repo.update_user(user) + class UserRoleServices(AsyncInitializingComponent): __order__ = 1 diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 6eb97d8..15a6bae 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,4 +1 @@ -from pathlib import Path - -PROJECT_ROOT = Path(__file__).joinpath("../../..").resolve() -SERVICES_PATH = PROJECT_ROOT.joinpath("src/services") +from ._path import PROJECT_ROOT, SERVICES_PATH diff --git a/src/utils/_path.py b/src/utils/_path.py new file mode 100644 index 0000000..e324100 --- /dev/null +++ b/src/utils/_path.py @@ -0,0 +1,9 @@ +from pathlib import Path + +PROJECT_ROOT = Path(__file__).joinpath("../../..").resolve() +SERVICES_PATH = PROJECT_ROOT.joinpath("src/services") + +DATA_PATH = PROJECT_ROOT / "data" +DATA_PATH.mkdir(exist_ok=True) +AVATAR_DATA_PATH = DATA_PATH / "avatar" +AVATAR_DATA_PATH.mkdir(exist_ok=True) diff --git a/src/utils/upload_file.py b/src/utils/upload_file.py new file mode 100644 index 0000000..3990d26 --- /dev/null +++ b/src/utils/upload_file.py @@ -0,0 +1,53 @@ +import hashlib +from pathlib import Path +from typing import Optional + +import aiofiles +from fastapi import UploadFile + +from ._path import AVATAR_DATA_PATH +from ..errors import ProjectBaseError + +AVATAR_MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB + + +async def check_avatar(uid: int, uri_path: str) -> bool: + try: + real_uid, file_path = uri_path.split("/") + if int(real_uid) != uid: + return False + path = AVATAR_DATA_PATH / f"{uid}" / f"{file_path}" + return path.exists() + except ValueError: + return False + + +async def get_avatar(uid: int, file_path: str) -> Optional[Path]: + path = AVATAR_DATA_PATH / f"{uid}" / f"{file_path}" + if not path.exists(): + return None + return path + + +async def save_avatar(uid: int, file: UploadFile) -> str: + filename = file.filename.lower() if file.filename else "" + if not filename or not filename.endswith(".jpg"): + raise ProjectBaseError("请上传 jpg 格式的文件") + + path = AVATAR_DATA_PATH / f"{uid}" + path.mkdir(exist_ok=True) + + file_data = await file.read() + if len(file_data) > AVATAR_MAX_FILE_SIZE: + raise ProjectBaseError("文件过大,请上传小于5MB的文件") + + name = ( + hashlib.md5(file_data).hexdigest() + ".jpg" + ) # 使用md5作为文件名,以免同一个文件多次写入 + file_path = path / name + if file_path.exists(): + return name + + async with aiofiles.open(file_path, "wb") as f: + await f.write(file_data) + return name diff --git a/uv.lock b/uv.lock index ef9dc23..985cf25 100644 --- a/uv.lock +++ b/uv.lock @@ -660,6 +660,7 @@ name = "yoloface-be" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "aiofiles" }, { name = "alembic" }, { name = "asyncmy" }, { name = "black" }, @@ -678,9 +679,10 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "aiofiles" }, { name = "alembic", specifier = ">=1.13.3" }, { name = "asyncmy", specifier = ">=0.2.9" }, - { name = "black" }, + { name = "black", specifier = ">=24.10.0" }, { name = "fakeredis", specifier = ">=2.26.1" }, { name = "fastapi", specifier = "==0.112.2" }, { name = "fastapi-amis-admin", specifier = ">=0.7.2" },