chore: type

This commit is contained in:
xtaodada 2024-11-05 18:41:30 +08:00
parent f5786d5d1e
commit a63b96ae80
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
4 changed files with 80 additions and 42 deletions

View File

@ -4,6 +4,7 @@ import uvicorn
from fastapi import FastAPI
from persica.factory.component import AsyncInitializingComponent
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from src.env import config
@ -26,6 +27,13 @@ class WebApp(AsyncInitializingComponent):
config.web.domain,
],
)
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
async def start(self):
print("开始启动 web 服务")

View File

@ -10,12 +10,13 @@ from starlette.responses import Response
from src.plugin import handler
from src.plugin.plugin import Plugin
from src.services.users.schemas import (
UserLoginData,
UserRegIn,
SystemUserEnum,
UserLoginOut,
UserRoleEnum,
)
from src.services.users.services import UserServices
from src.services.users.services import UserServices, UserRoleServices
if TYPE_CHECKING:
from fastapi_user_auth.auth import Auth
@ -24,28 +25,31 @@ if TYPE_CHECKING:
class UserRoutes(Plugin):
_prefix = "/user"
def __init__(self, user_services: UserServices):
def __init__(
self, user_services: UserServices, user_role_services: UserRoleServices
):
self.user_services = user_services
self.user_role_services = user_role_services
@handler.post("/register", admin=False)
async def register(self, data: UserRegIn):
if data.username.upper() in SystemUserEnum.__members__:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
return BaseApiOut(status=500, msg="用户名已被注册", data=None)
user = await self.user_services.get_user(username=data.username)
if user:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
return BaseApiOut(status=500, msg="用户名已被注册", data=None)
role = UserRoleEnum.STUDENT.value
if not (data.student_id or data.phone):
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
return BaseApiOut(status=500, msg="学号或手机号至少填写一项", data=None)
if data.student_id:
user = await self.user_services.get_user(student_id=data.student_id)
if user:
return BaseApiOut(status=-1, msg="学号已被注册", data=None)
return BaseApiOut(status=500, msg="学号已被注册", data=None)
role = UserRoleEnum.STUDENT.value
if data.phone:
user = await self.user_services.get_user(phone=data.phone)
if user:
return BaseApiOut(status=-1, msg="手机号已被注册", data=None)
return BaseApiOut(status=500, msg="手机号已被注册", data=None)
role = UserRoleEnum.OUT.value
# 检查通过,注册用户
try:
@ -55,8 +59,12 @@ class UserRoutes(Plugin):
student_id=data.student_id,
phone=data.phone,
)
if not await self.user_services.is_user_in_role_group(data.username, role):
await self.user_services.add_user_to_role_group(data.username, role)
if not await self.user_role_services.is_user_in_role_group(
data.username, role
):
await self.user_role_services.add_user_to_role_group(
data.username, role
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -65,6 +73,7 @@ class UserRoutes(Plugin):
# 注册成功,设置用户信息
token_info = UserLoginOut.model_validate(user)
token_info.access_token = await self.user_services.login_user(user)
token_info.roles = role
return BaseApiOut(code=0, msg="注册成功", data=token_info)
async def create_login_history(self, request: "Request"):
@ -83,18 +92,19 @@ class UserRoutes(Plugin):
@handler.post("/login", admin=False)
async def login(
self, request: Request, response: Response, username: str, password: str
self,
request: Request,
response: Response,
user: UserLoginData,
):
auth: "Auth" = request.auth
if request.scope.get("user"):
return BaseApiOut(
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
user = await auth.authenticate_user(
username=user.username, password=user.password
)
user = await auth.authenticate_user(username=username, password=password)
if not user:
return BaseApiOut(status=-1, msg="用户名或密码错误")
return BaseApiOut(status=500, msg="用户名或密码错误")
if not user.is_active:
return BaseApiOut(status=-2, msg="用户未激活")
return BaseApiOut(status=500, msg="用户未激活")
request.scope["user"] = user
try:
@ -109,6 +119,9 @@ class UserRoutes(Plugin):
token_info.access_token = await auth.backend.token_store.write_token(
request.user.dict()
)
token_info.roles = ",".join(
await self.user_role_services.get_user_roles(user.username)
)
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
return BaseApiOut(code=0, data=token_info)

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import Optional, List
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
from fastapi_amis_admin.utils.translation import i18n as _
@ -16,11 +16,17 @@ class BaseTokenData(BaseModel):
username: str
class UserLoginData(BaseModel):
username: str
password: str
class UserLoginOut(UserModel):
"""用户登录返回信息"""
token_type: str = "bearer"
access_token: Optional[str] = None
roles: str = ""
password: SecretStr = Field(
title=_("Password"),
max_length=128,

View File

@ -16,17 +16,6 @@ class UserServices(AsyncInitializingComponent):
def __init__(self, repo: UserRepo):
self.repo = repo
self.user_model = UserModel
self.role_model = RoleModel
self.rule_model = CasbinRule
async def initialize(self):
for g in UserRoleEnum.__members__.keys():
key = g.lower()
if key == "admin":
continue
if await self.get_role(key=key) is None:
await self.create_role(key, f"{key} role")
print(f"Create role: {key}")
async def register_user(
self,
@ -53,11 +42,46 @@ class UserServices(AsyncInitializingComponent):
) -> Optional[UserModel]:
return await self.repo.get_user(username, student_id, phone)
async def create_login_history(
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
):
history = LoginHistory(
user_id=user.id,
login_name=user.username,
ip=ip,
user_agent=ua,
login_status="登录成功",
forwarded_for=forwarded_for,
)
return await self.repo.create_login_history(history)
class UserRoleServices(AsyncInitializingComponent):
__order__ = 1
def __init__(self, repo: UserRepo):
self.repo = repo
self.role_model = RoleModel
self.rule_model = CasbinRule
async def initialize(self):
for g in UserRoleEnum.__members__.keys():
key = g.lower()
if key == "admin":
continue
if await self.get_role(key=key) is None:
await self.create_role(key, f"{key} role")
print(f"Create role: {key}")
async def get_role(
self, rid: Optional[int] = None, key: Optional[str] = None
) -> Optional[RoleModel]:
return await self.repo.get_role(rid, key)
async def get_user_roles(self, username: str) -> List[str]:
role_keys = await self.repo.AUTH.enforcer.get_roles_for_user(f"u:{username}")
return [i.replace("r:", "") for i in role_keys]
async def create_role(
self, key: str, name: str, description: Optional[str] = None
) -> RoleModel:
@ -86,16 +110,3 @@ class UserServices(AsyncInitializingComponent):
await update_subject_roles(
self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles
)
async def create_login_history(
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
):
history = LoginHistory(
user_id=user.id,
login_name=user.username,
ip=ip,
user_agent=ua,
login_status="登录成功",
forwarded_for=forwarded_for,
)
return await self.repo.create_login_history(history)