From a63b96ae805d921a7b2c01306bcf3b42d66d493d Mon Sep 17 00:00:00 2001 From: xtaodada Date: Tue, 5 Nov 2024 18:41:30 +0800 Subject: [PATCH] chore: type --- src/core/web_app.py | 8 +++++ src/route/users.py | 47 +++++++++++++++++---------- src/services/users/schemas.py | 8 ++++- src/services/users/services.py | 59 ++++++++++++++++++++-------------- 4 files changed, 80 insertions(+), 42 deletions(-) diff --git a/src/core/web_app.py b/src/core/web_app.py index ba72322..91a2f9e 100644 --- a/src/core/web_app.py +++ b/src/core/web_app.py @@ -4,6 +4,7 @@ import uvicorn from fastapi import FastAPI from persica.factory.component import AsyncInitializingComponent +from starlette.middleware.cors import CORSMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware from src.env import config @@ -26,6 +27,13 @@ class WebApp(AsyncInitializingComponent): config.web.domain, ], ) + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) async def start(self): print("开始启动 web 服务") diff --git a/src/route/users.py b/src/route/users.py index f377402..c6e4bb5 100644 --- a/src/route/users.py +++ b/src/route/users.py @@ -10,12 +10,13 @@ from starlette.responses import Response from src.plugin import handler from src.plugin.plugin import Plugin from src.services.users.schemas import ( + UserLoginData, UserRegIn, SystemUserEnum, UserLoginOut, UserRoleEnum, ) -from src.services.users.services import UserServices +from src.services.users.services import UserServices, UserRoleServices if TYPE_CHECKING: from fastapi_user_auth.auth import Auth @@ -24,28 +25,31 @@ if TYPE_CHECKING: class UserRoutes(Plugin): _prefix = "/user" - def __init__(self, user_services: UserServices): + def __init__( + self, user_services: UserServices, user_role_services: UserRoleServices + ): self.user_services = user_services + self.user_role_services = user_role_services @handler.post("/register", admin=False) async def register(self, data: UserRegIn): if data.username.upper() in SystemUserEnum.__members__: - return BaseApiOut(status=-1, msg="用户名已被注册", data=None) + return BaseApiOut(status=500, msg="用户名已被注册", data=None) user = await self.user_services.get_user(username=data.username) if user: - return BaseApiOut(status=-1, msg="用户名已被注册", data=None) + return BaseApiOut(status=500, msg="用户名已被注册", data=None) role = UserRoleEnum.STUDENT.value if not (data.student_id or data.phone): - return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None) + return BaseApiOut(status=500, msg="学号或手机号至少填写一项", data=None) if data.student_id: user = await self.user_services.get_user(student_id=data.student_id) if user: - return BaseApiOut(status=-1, msg="学号已被注册", data=None) + return BaseApiOut(status=500, msg="学号已被注册", data=None) role = UserRoleEnum.STUDENT.value if data.phone: user = await self.user_services.get_user(phone=data.phone) if user: - return BaseApiOut(status=-1, msg="手机号已被注册", data=None) + return BaseApiOut(status=500, msg="手机号已被注册", data=None) role = UserRoleEnum.OUT.value # 检查通过,注册用户 try: @@ -55,8 +59,12 @@ class UserRoutes(Plugin): student_id=data.student_id, phone=data.phone, ) - if not await self.user_services.is_user_in_role_group(data.username, role): - await self.user_services.add_user_to_role_group(data.username, role) + if not await self.user_role_services.is_user_in_role_group( + data.username, role + ): + await self.user_role_services.add_user_to_role_group( + data.username, role + ) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -65,6 +73,7 @@ class UserRoutes(Plugin): # 注册成功,设置用户信息 token_info = UserLoginOut.model_validate(user) token_info.access_token = await self.user_services.login_user(user) + token_info.roles = role return BaseApiOut(code=0, msg="注册成功", data=token_info) async def create_login_history(self, request: "Request"): @@ -83,18 +92,19 @@ class UserRoutes(Plugin): @handler.post("/login", admin=False) async def login( - self, request: Request, response: Response, username: str, password: str + self, + request: Request, + response: Response, + user: UserLoginData, ): auth: "Auth" = request.auth - if request.scope.get("user"): - return BaseApiOut( - code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user) - ) - user = await auth.authenticate_user(username=username, password=password) + user = await auth.authenticate_user( + username=user.username, password=user.password + ) if not user: - return BaseApiOut(status=-1, msg="用户名或密码错误") + return BaseApiOut(status=500, msg="用户名或密码错误") if not user.is_active: - return BaseApiOut(status=-2, msg="用户未激活") + return BaseApiOut(status=500, msg="用户未激活") request.scope["user"] = user try: @@ -109,6 +119,9 @@ class UserRoutes(Plugin): token_info.access_token = await auth.backend.token_store.write_token( request.user.dict() ) + token_info.roles = ",".join( + await self.user_role_services.get_user_roles(user.username) + ) response.set_cookie("Authorization", f"bearer {token_info.access_token}") return BaseApiOut(code=0, data=token_info) diff --git a/src/services/users/schemas.py b/src/services/users/schemas.py index b301a4e..4055ba6 100644 --- a/src/services/users/schemas.py +++ b/src/services/users/schemas.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Optional, List from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2 from fastapi_amis_admin.utils.translation import i18n as _ @@ -16,11 +16,17 @@ class BaseTokenData(BaseModel): username: str +class UserLoginData(BaseModel): + username: str + password: str + + class UserLoginOut(UserModel): """用户登录返回信息""" token_type: str = "bearer" access_token: Optional[str] = None + roles: str = "" password: SecretStr = Field( title=_("Password"), max_length=128, diff --git a/src/services/users/services.py b/src/services/users/services.py index 7c32a91..6d6033c 100644 --- a/src/services/users/services.py +++ b/src/services/users/services.py @@ -16,17 +16,6 @@ class UserServices(AsyncInitializingComponent): def __init__(self, repo: UserRepo): self.repo = repo self.user_model = UserModel - self.role_model = RoleModel - self.rule_model = CasbinRule - - async def initialize(self): - for g in UserRoleEnum.__members__.keys(): - key = g.lower() - if key == "admin": - continue - if await self.get_role(key=key) is None: - await self.create_role(key, f"{key} role") - print(f"Create role: {key}") async def register_user( self, @@ -53,11 +42,46 @@ class UserServices(AsyncInitializingComponent): ) -> Optional[UserModel]: return await self.repo.get_user(username, student_id, phone) + async def create_login_history( + self, user: "UserModel", ip: str, ua: str, forwarded_for: str + ): + history = LoginHistory( + user_id=user.id, + login_name=user.username, + ip=ip, + user_agent=ua, + login_status="登录成功", + forwarded_for=forwarded_for, + ) + return await self.repo.create_login_history(history) + + +class UserRoleServices(AsyncInitializingComponent): + __order__ = 1 + + def __init__(self, repo: UserRepo): + self.repo = repo + self.role_model = RoleModel + self.rule_model = CasbinRule + + async def initialize(self): + for g in UserRoleEnum.__members__.keys(): + key = g.lower() + if key == "admin": + continue + if await self.get_role(key=key) is None: + await self.create_role(key, f"{key} role") + print(f"Create role: {key}") + async def get_role( self, rid: Optional[int] = None, key: Optional[str] = None ) -> Optional[RoleModel]: return await self.repo.get_role(rid, key) + async def get_user_roles(self, username: str) -> List[str]: + role_keys = await self.repo.AUTH.enforcer.get_roles_for_user(f"u:{username}") + return [i.replace("r:", "") for i in role_keys] + async def create_role( self, key: str, name: str, description: Optional[str] = None ) -> RoleModel: @@ -86,16 +110,3 @@ class UserServices(AsyncInitializingComponent): await update_subject_roles( self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles ) - - async def create_login_history( - self, user: "UserModel", ip: str, ua: str, forwarded_for: str - ): - history = LoginHistory( - user_id=user.id, - login_name=user.username, - ip=ip, - user_agent=ua, - login_status="登录成功", - forwarded_for=forwarded_for, - ) - return await self.repo.create_login_history(history)