feat: user profile update
This commit is contained in:
parent
986fc90826
commit
184a97e486
64
src/route/users_update.py
Normal file
64
src/route/users_update.py
Normal file
@ -0,0 +1,64 @@
|
||||
from fastapi_amis_admin.crud import BaseApiOut
|
||||
from starlette import status
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from src.plugin import handler
|
||||
from src.plugin.plugin import Plugin
|
||||
from src.services.users.models import UserModel
|
||||
from src.services.users.schemas import UserUpdate
|
||||
from src.services.users.services import UserServices, UserRoleServices
|
||||
|
||||
|
||||
class UserUpdateRoutes(Plugin):
|
||||
_prefix = "/user"
|
||||
|
||||
def __init__(
|
||||
self, user_services: UserServices, user_role_services: UserRoleServices
|
||||
):
|
||||
self.user_services = user_services
|
||||
self.user_role_services = user_role_services
|
||||
|
||||
@handler.get("/me", student=True, out=True)
|
||||
async def get_me(self, request: Request):
|
||||
user: "UserModel" = request.user
|
||||
user = await self.user_services.get_user(username=user.username)
|
||||
if user:
|
||||
return BaseApiOut(code=0, msg="查询成功", data=user)
|
||||
return BaseApiOut(status=500, msg="查询失败,内部服务器错误")
|
||||
|
||||
@handler.post("/update", student=True, out=True)
|
||||
async def update_user(self, request: Request, new_user: UserUpdate):
|
||||
user: "UserModel" = request.user
|
||||
if new_user.sex not in ["男", "女"]:
|
||||
return BaseApiOut(status=500, msg="请选择性别")
|
||||
need_change_password = False
|
||||
if new_user.old_password and new_user.password and new_user.password2:
|
||||
if new_user.old_password == new_user.password:
|
||||
return BaseApiOut(status=500, msg="新密码不能与旧密码相同")
|
||||
if new_user.password != new_user.password2:
|
||||
return BaseApiOut(status=500, msg="两次输入密码不一致")
|
||||
need_change_password = True
|
||||
user = await self.user_services.get_user(username=user.username)
|
||||
if not user:
|
||||
return BaseApiOut(status=500, msg="用户不存在")
|
||||
|
||||
try:
|
||||
user = await self.user_services.update_user_profile(
|
||||
username=user.username,
|
||||
nickname=new_user.nickname,
|
||||
email=new_user.email,
|
||||
real_name=new_user.real_name,
|
||||
sex=new_user.sex,
|
||||
old_password=new_user.old_password if need_change_password else None,
|
||||
password=new_user.password if need_change_password else None,
|
||||
)
|
||||
return BaseApiOut(code=0, msg="更新成功", data=user)
|
||||
except FileNotFoundError:
|
||||
return BaseApiOut(status=500, msg="旧密码错误")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error Execute SQL:{e}",
|
||||
) from e
|
@ -126,3 +126,10 @@ class UserRepo(AsyncInitializingComponent):
|
||||
await session.commit()
|
||||
await session.refresh(login_history)
|
||||
return login_history
|
||||
|
||||
async def update_user(self, user: "UserModel") -> Optional[UserModel]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
@ -1,9 +1,10 @@
|
||||
from email.policy import default
|
||||
from enum import Enum
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
||||
from fastapi_amis_admin.utils.translation import i18n as _
|
||||
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin
|
||||
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin
|
||||
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from sqlmodel import Field
|
||||
@ -68,6 +69,26 @@ class UserRegIn(
|
||||
return v
|
||||
|
||||
|
||||
class UserUpdate(
|
||||
EmailMixin,
|
||||
RealNameMixin,
|
||||
SexMixin,
|
||||
):
|
||||
"""用户修改个人资料"""
|
||||
|
||||
nickname: Optional[str] = Field("", title=_("Nickname"), max_length=40)
|
||||
|
||||
old_password: Optional[str] = Field(
|
||||
default="", title=_("Confirm Password"), max_length=128
|
||||
)
|
||||
password: Optional[str] = Field(
|
||||
default="", title=_("Confirm Password"), max_length=128
|
||||
)
|
||||
password2: Optional[str] = Field(
|
||||
default="", title=_("Confirm Password"), max_length=128
|
||||
)
|
||||
|
||||
|
||||
# 默认保留的用户
|
||||
class SystemUserEnum(str, Enum):
|
||||
ROOT = "root"
|
||||
|
@ -57,6 +57,35 @@ class UserServices(AsyncInitializingComponent):
|
||||
)
|
||||
return await self.repo.create_login_history(history)
|
||||
|
||||
async def update_user_profile(
|
||||
self,
|
||||
username: str,
|
||||
nickname: Optional[str],
|
||||
email: Optional[str],
|
||||
real_name: Optional[str],
|
||||
sex: Optional[str],
|
||||
old_password: Optional[str],
|
||||
password: Optional[str],
|
||||
) -> Optional[UserModel]:
|
||||
user = await self.get_user(username=username)
|
||||
if not user:
|
||||
return None
|
||||
if nickname is not None:
|
||||
user.nickname = nickname
|
||||
if email is not None:
|
||||
user.email = email
|
||||
if real_name is not None:
|
||||
user.real_name = real_name
|
||||
if sex is not None:
|
||||
user.sex = sex
|
||||
if old_password and password:
|
||||
if not self.repo.AUTH.pwd_context.verify(
|
||||
old_password, user.password.get_secret_value()
|
||||
):
|
||||
raise FileNotFoundError("Old password is incorrect")
|
||||
user.password = self.repo.AUTH.pwd_context.hash(password)
|
||||
return await self.repo.update_user(user)
|
||||
|
||||
|
||||
class UserRoleServices(AsyncInitializingComponent):
|
||||
__order__ = 1
|
||||
|
Loading…
Reference in New Issue
Block a user