chore: type
This commit is contained in:
parent
f5786d5d1e
commit
a63b96ae80
@ -4,6 +4,7 @@ import uvicorn
|
||||
|
||||
from fastapi import FastAPI
|
||||
from persica.factory.component import AsyncInitializingComponent
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
from src.env import config
|
||||
@ -26,6 +27,13 @@ class WebApp(AsyncInitializingComponent):
|
||||
config.web.domain,
|
||||
],
|
||||
)
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
print("开始启动 web 服务")
|
||||
|
@ -10,12 +10,13 @@ from starlette.responses import Response
|
||||
from src.plugin import handler
|
||||
from src.plugin.plugin import Plugin
|
||||
from src.services.users.schemas import (
|
||||
UserLoginData,
|
||||
UserRegIn,
|
||||
SystemUserEnum,
|
||||
UserLoginOut,
|
||||
UserRoleEnum,
|
||||
)
|
||||
from src.services.users.services import UserServices
|
||||
from src.services.users.services import UserServices, UserRoleServices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi_user_auth.auth import Auth
|
||||
@ -24,28 +25,31 @@ if TYPE_CHECKING:
|
||||
class UserRoutes(Plugin):
|
||||
_prefix = "/user"
|
||||
|
||||
def __init__(self, user_services: UserServices):
|
||||
def __init__(
|
||||
self, user_services: UserServices, user_role_services: UserRoleServices
|
||||
):
|
||||
self.user_services = user_services
|
||||
self.user_role_services = user_role_services
|
||||
|
||||
@handler.post("/register", admin=False)
|
||||
async def register(self, data: UserRegIn):
|
||||
if data.username.upper() in SystemUserEnum.__members__:
|
||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||
return BaseApiOut(status=500, msg="用户名已被注册", data=None)
|
||||
user = await self.user_services.get_user(username=data.username)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||
return BaseApiOut(status=500, msg="用户名已被注册", data=None)
|
||||
role = UserRoleEnum.STUDENT.value
|
||||
if not (data.student_id or data.phone):
|
||||
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
||||
return BaseApiOut(status=500, msg="学号或手机号至少填写一项", data=None)
|
||||
if data.student_id:
|
||||
user = await self.user_services.get_user(student_id=data.student_id)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="学号已被注册", data=None)
|
||||
return BaseApiOut(status=500, msg="学号已被注册", data=None)
|
||||
role = UserRoleEnum.STUDENT.value
|
||||
if data.phone:
|
||||
user = await self.user_services.get_user(phone=data.phone)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="手机号已被注册", data=None)
|
||||
return BaseApiOut(status=500, msg="手机号已被注册", data=None)
|
||||
role = UserRoleEnum.OUT.value
|
||||
# 检查通过,注册用户
|
||||
try:
|
||||
@ -55,8 +59,12 @@ class UserRoutes(Plugin):
|
||||
student_id=data.student_id,
|
||||
phone=data.phone,
|
||||
)
|
||||
if not await self.user_services.is_user_in_role_group(data.username, role):
|
||||
await self.user_services.add_user_to_role_group(data.username, role)
|
||||
if not await self.user_role_services.is_user_in_role_group(
|
||||
data.username, role
|
||||
):
|
||||
await self.user_role_services.add_user_to_role_group(
|
||||
data.username, role
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
@ -65,6 +73,7 @@ class UserRoutes(Plugin):
|
||||
# 注册成功,设置用户信息
|
||||
token_info = UserLoginOut.model_validate(user)
|
||||
token_info.access_token = await self.user_services.login_user(user)
|
||||
token_info.roles = role
|
||||
return BaseApiOut(code=0, msg="注册成功", data=token_info)
|
||||
|
||||
async def create_login_history(self, request: "Request"):
|
||||
@ -83,18 +92,19 @@ class UserRoutes(Plugin):
|
||||
|
||||
@handler.post("/login", admin=False)
|
||||
async def login(
|
||||
self, request: Request, response: Response, username: str, password: str
|
||||
self,
|
||||
request: Request,
|
||||
response: Response,
|
||||
user: UserLoginData,
|
||||
):
|
||||
auth: "Auth" = request.auth
|
||||
if request.scope.get("user"):
|
||||
return BaseApiOut(
|
||||
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
|
||||
user = await auth.authenticate_user(
|
||||
username=user.username, password=user.password
|
||||
)
|
||||
user = await auth.authenticate_user(username=username, password=password)
|
||||
if not user:
|
||||
return BaseApiOut(status=-1, msg="用户名或密码错误")
|
||||
return BaseApiOut(status=500, msg="用户名或密码错误")
|
||||
if not user.is_active:
|
||||
return BaseApiOut(status=-2, msg="用户未激活")
|
||||
return BaseApiOut(status=500, msg="用户未激活")
|
||||
request.scope["user"] = user
|
||||
|
||||
try:
|
||||
@ -109,6 +119,9 @@ class UserRoutes(Plugin):
|
||||
token_info.access_token = await auth.backend.token_store.write_token(
|
||||
request.user.dict()
|
||||
)
|
||||
token_info.roles = ",".join(
|
||||
await self.user_role_services.get_user_roles(user.username)
|
||||
)
|
||||
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
|
||||
return BaseApiOut(code=0, data=token_info)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
||||
from fastapi_amis_admin.utils.translation import i18n as _
|
||||
@ -16,11 +16,17 @@ class BaseTokenData(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
class UserLoginData(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class UserLoginOut(UserModel):
|
||||
"""用户登录返回信息"""
|
||||
|
||||
token_type: str = "bearer"
|
||||
access_token: Optional[str] = None
|
||||
roles: str = ""
|
||||
password: SecretStr = Field(
|
||||
title=_("Password"),
|
||||
max_length=128,
|
||||
|
@ -16,17 +16,6 @@ class UserServices(AsyncInitializingComponent):
|
||||
def __init__(self, repo: UserRepo):
|
||||
self.repo = repo
|
||||
self.user_model = UserModel
|
||||
self.role_model = RoleModel
|
||||
self.rule_model = CasbinRule
|
||||
|
||||
async def initialize(self):
|
||||
for g in UserRoleEnum.__members__.keys():
|
||||
key = g.lower()
|
||||
if key == "admin":
|
||||
continue
|
||||
if await self.get_role(key=key) is None:
|
||||
await self.create_role(key, f"{key} role")
|
||||
print(f"Create role: {key}")
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
@ -53,11 +42,46 @@ class UserServices(AsyncInitializingComponent):
|
||||
) -> Optional[UserModel]:
|
||||
return await self.repo.get_user(username, student_id, phone)
|
||||
|
||||
async def create_login_history(
|
||||
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
||||
):
|
||||
history = LoginHistory(
|
||||
user_id=user.id,
|
||||
login_name=user.username,
|
||||
ip=ip,
|
||||
user_agent=ua,
|
||||
login_status="登录成功",
|
||||
forwarded_for=forwarded_for,
|
||||
)
|
||||
return await self.repo.create_login_history(history)
|
||||
|
||||
|
||||
class UserRoleServices(AsyncInitializingComponent):
|
||||
__order__ = 1
|
||||
|
||||
def __init__(self, repo: UserRepo):
|
||||
self.repo = repo
|
||||
self.role_model = RoleModel
|
||||
self.rule_model = CasbinRule
|
||||
|
||||
async def initialize(self):
|
||||
for g in UserRoleEnum.__members__.keys():
|
||||
key = g.lower()
|
||||
if key == "admin":
|
||||
continue
|
||||
if await self.get_role(key=key) is None:
|
||||
await self.create_role(key, f"{key} role")
|
||||
print(f"Create role: {key}")
|
||||
|
||||
async def get_role(
|
||||
self, rid: Optional[int] = None, key: Optional[str] = None
|
||||
) -> Optional[RoleModel]:
|
||||
return await self.repo.get_role(rid, key)
|
||||
|
||||
async def get_user_roles(self, username: str) -> List[str]:
|
||||
role_keys = await self.repo.AUTH.enforcer.get_roles_for_user(f"u:{username}")
|
||||
return [i.replace("r:", "") for i in role_keys]
|
||||
|
||||
async def create_role(
|
||||
self, key: str, name: str, description: Optional[str] = None
|
||||
) -> RoleModel:
|
||||
@ -86,16 +110,3 @@ class UserServices(AsyncInitializingComponent):
|
||||
await update_subject_roles(
|
||||
self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles
|
||||
)
|
||||
|
||||
async def create_login_history(
|
||||
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
||||
):
|
||||
history = LoginHistory(
|
||||
user_id=user.id,
|
||||
login_name=user.username,
|
||||
ip=ip,
|
||||
user_agent=ua,
|
||||
login_status="登录成功",
|
||||
forwarded_for=forwarded_for,
|
||||
)
|
||||
return await self.repo.create_login_history(history)
|
||||
|
Loading…
Reference in New Issue
Block a user