chore: type

This commit is contained in:
xtaodada 2024-11-05 18:41:30 +08:00
parent f5786d5d1e
commit a63b96ae80
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
4 changed files with 80 additions and 42 deletions

View File

@ -4,6 +4,7 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from persica.factory.component import AsyncInitializingComponent from persica.factory.component import AsyncInitializingComponent
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware
from src.env import config from src.env import config
@ -26,6 +27,13 @@ class WebApp(AsyncInitializingComponent):
config.web.domain, config.web.domain,
], ],
) )
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
async def start(self): async def start(self):
print("开始启动 web 服务") print("开始启动 web 服务")

View File

@ -10,12 +10,13 @@ from starlette.responses import Response
from src.plugin import handler from src.plugin import handler
from src.plugin.plugin import Plugin from src.plugin.plugin import Plugin
from src.services.users.schemas import ( from src.services.users.schemas import (
UserLoginData,
UserRegIn, UserRegIn,
SystemUserEnum, SystemUserEnum,
UserLoginOut, UserLoginOut,
UserRoleEnum, UserRoleEnum,
) )
from src.services.users.services import UserServices from src.services.users.services import UserServices, UserRoleServices
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi_user_auth.auth import Auth from fastapi_user_auth.auth import Auth
@ -24,28 +25,31 @@ if TYPE_CHECKING:
class UserRoutes(Plugin): class UserRoutes(Plugin):
_prefix = "/user" _prefix = "/user"
def __init__(self, user_services: UserServices): def __init__(
self, user_services: UserServices, user_role_services: UserRoleServices
):
self.user_services = user_services self.user_services = user_services
self.user_role_services = user_role_services
@handler.post("/register", admin=False) @handler.post("/register", admin=False)
async def register(self, data: UserRegIn): async def register(self, data: UserRegIn):
if data.username.upper() in SystemUserEnum.__members__: if data.username.upper() in SystemUserEnum.__members__:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None) return BaseApiOut(status=500, msg="用户名已被注册", data=None)
user = await self.user_services.get_user(username=data.username) user = await self.user_services.get_user(username=data.username)
if user: if user:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None) return BaseApiOut(status=500, msg="用户名已被注册", data=None)
role = UserRoleEnum.STUDENT.value role = UserRoleEnum.STUDENT.value
if not (data.student_id or data.phone): if not (data.student_id or data.phone):
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None) return BaseApiOut(status=500, msg="学号或手机号至少填写一项", data=None)
if data.student_id: if data.student_id:
user = await self.user_services.get_user(student_id=data.student_id) user = await self.user_services.get_user(student_id=data.student_id)
if user: if user:
return BaseApiOut(status=-1, msg="学号已被注册", data=None) return BaseApiOut(status=500, msg="学号已被注册", data=None)
role = UserRoleEnum.STUDENT.value role = UserRoleEnum.STUDENT.value
if data.phone: if data.phone:
user = await self.user_services.get_user(phone=data.phone) user = await self.user_services.get_user(phone=data.phone)
if user: if user:
return BaseApiOut(status=-1, msg="手机号已被注册", data=None) return BaseApiOut(status=500, msg="手机号已被注册", data=None)
role = UserRoleEnum.OUT.value role = UserRoleEnum.OUT.value
# 检查通过,注册用户 # 检查通过,注册用户
try: try:
@ -55,8 +59,12 @@ class UserRoutes(Plugin):
student_id=data.student_id, student_id=data.student_id,
phone=data.phone, phone=data.phone,
) )
if not await self.user_services.is_user_in_role_group(data.username, role): if not await self.user_role_services.is_user_in_role_group(
await self.user_services.add_user_to_role_group(data.username, role) data.username, role
):
await self.user_role_services.add_user_to_role_group(
data.username, role
)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -65,6 +73,7 @@ class UserRoutes(Plugin):
# 注册成功,设置用户信息 # 注册成功,设置用户信息
token_info = UserLoginOut.model_validate(user) token_info = UserLoginOut.model_validate(user)
token_info.access_token = await self.user_services.login_user(user) token_info.access_token = await self.user_services.login_user(user)
token_info.roles = role
return BaseApiOut(code=0, msg="注册成功", data=token_info) return BaseApiOut(code=0, msg="注册成功", data=token_info)
async def create_login_history(self, request: "Request"): async def create_login_history(self, request: "Request"):
@ -83,18 +92,19 @@ class UserRoutes(Plugin):
@handler.post("/login", admin=False) @handler.post("/login", admin=False)
async def login( async def login(
self, request: Request, response: Response, username: str, password: str self,
request: Request,
response: Response,
user: UserLoginData,
): ):
auth: "Auth" = request.auth auth: "Auth" = request.auth
if request.scope.get("user"): user = await auth.authenticate_user(
return BaseApiOut( username=user.username, password=user.password
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
) )
user = await auth.authenticate_user(username=username, password=password)
if not user: if not user:
return BaseApiOut(status=-1, msg="用户名或密码错误") return BaseApiOut(status=500, msg="用户名或密码错误")
if not user.is_active: if not user.is_active:
return BaseApiOut(status=-2, msg="用户未激活") return BaseApiOut(status=500, msg="用户未激活")
request.scope["user"] = user request.scope["user"] = user
try: try:
@ -109,6 +119,9 @@ class UserRoutes(Plugin):
token_info.access_token = await auth.backend.token_store.write_token( token_info.access_token = await auth.backend.token_store.write_token(
request.user.dict() request.user.dict()
) )
token_info.roles = ",".join(
await self.user_role_services.get_user_roles(user.username)
)
response.set_cookie("Authorization", f"bearer {token_info.access_token}") response.set_cookie("Authorization", f"bearer {token_info.access_token}")
return BaseApiOut(code=0, data=token_info) return BaseApiOut(code=0, data=token_info)

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, List
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2 from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
from fastapi_amis_admin.utils.translation import i18n as _ from fastapi_amis_admin.utils.translation import i18n as _
@ -16,11 +16,17 @@ class BaseTokenData(BaseModel):
username: str username: str
class UserLoginData(BaseModel):
username: str
password: str
class UserLoginOut(UserModel): class UserLoginOut(UserModel):
"""用户登录返回信息""" """用户登录返回信息"""
token_type: str = "bearer" token_type: str = "bearer"
access_token: Optional[str] = None access_token: Optional[str] = None
roles: str = ""
password: SecretStr = Field( password: SecretStr = Field(
title=_("Password"), title=_("Password"),
max_length=128, max_length=128,

View File

@ -16,17 +16,6 @@ class UserServices(AsyncInitializingComponent):
def __init__(self, repo: UserRepo): def __init__(self, repo: UserRepo):
self.repo = repo self.repo = repo
self.user_model = UserModel self.user_model = UserModel
self.role_model = RoleModel
self.rule_model = CasbinRule
async def initialize(self):
for g in UserRoleEnum.__members__.keys():
key = g.lower()
if key == "admin":
continue
if await self.get_role(key=key) is None:
await self.create_role(key, f"{key} role")
print(f"Create role: {key}")
async def register_user( async def register_user(
self, self,
@ -53,11 +42,46 @@ class UserServices(AsyncInitializingComponent):
) -> Optional[UserModel]: ) -> Optional[UserModel]:
return await self.repo.get_user(username, student_id, phone) return await self.repo.get_user(username, student_id, phone)
async def create_login_history(
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
):
history = LoginHistory(
user_id=user.id,
login_name=user.username,
ip=ip,
user_agent=ua,
login_status="登录成功",
forwarded_for=forwarded_for,
)
return await self.repo.create_login_history(history)
class UserRoleServices(AsyncInitializingComponent):
__order__ = 1
def __init__(self, repo: UserRepo):
self.repo = repo
self.role_model = RoleModel
self.rule_model = CasbinRule
async def initialize(self):
for g in UserRoleEnum.__members__.keys():
key = g.lower()
if key == "admin":
continue
if await self.get_role(key=key) is None:
await self.create_role(key, f"{key} role")
print(f"Create role: {key}")
async def get_role( async def get_role(
self, rid: Optional[int] = None, key: Optional[str] = None self, rid: Optional[int] = None, key: Optional[str] = None
) -> Optional[RoleModel]: ) -> Optional[RoleModel]:
return await self.repo.get_role(rid, key) return await self.repo.get_role(rid, key)
async def get_user_roles(self, username: str) -> List[str]:
role_keys = await self.repo.AUTH.enforcer.get_roles_for_user(f"u:{username}")
return [i.replace("r:", "") for i in role_keys]
async def create_role( async def create_role(
self, key: str, name: str, description: Optional[str] = None self, key: str, name: str, description: Optional[str] = None
) -> RoleModel: ) -> RoleModel:
@ -86,16 +110,3 @@ class UserServices(AsyncInitializingComponent):
await update_subject_roles( await update_subject_roles(
self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles
) )
async def create_login_history(
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
):
history = LoginHistory(
user_id=user.id,
login_name=user.username,
ip=ip,
user_agent=ua,
login_status="登录成功",
forwarded_for=forwarded_for,
)
return await self.repo.create_login_history(history)