chore: type
This commit is contained in:
parent
f5786d5d1e
commit
a63b96ae80
@ -4,6 +4,7 @@ import uvicorn
|
|||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from persica.factory.component import AsyncInitializingComponent
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||||
|
|
||||||
from src.env import config
|
from src.env import config
|
||||||
@ -26,6 +27,13 @@ class WebApp(AsyncInitializingComponent):
|
|||||||
config.web.domain,
|
config.web.domain,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
self.app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
print("开始启动 web 服务")
|
print("开始启动 web 服务")
|
||||||
|
@ -10,12 +10,13 @@ from starlette.responses import Response
|
|||||||
from src.plugin import handler
|
from src.plugin import handler
|
||||||
from src.plugin.plugin import Plugin
|
from src.plugin.plugin import Plugin
|
||||||
from src.services.users.schemas import (
|
from src.services.users.schemas import (
|
||||||
|
UserLoginData,
|
||||||
UserRegIn,
|
UserRegIn,
|
||||||
SystemUserEnum,
|
SystemUserEnum,
|
||||||
UserLoginOut,
|
UserLoginOut,
|
||||||
UserRoleEnum,
|
UserRoleEnum,
|
||||||
)
|
)
|
||||||
from src.services.users.services import UserServices
|
from src.services.users.services import UserServices, UserRoleServices
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi_user_auth.auth import Auth
|
from fastapi_user_auth.auth import Auth
|
||||||
@ -24,28 +25,31 @@ if TYPE_CHECKING:
|
|||||||
class UserRoutes(Plugin):
|
class UserRoutes(Plugin):
|
||||||
_prefix = "/user"
|
_prefix = "/user"
|
||||||
|
|
||||||
def __init__(self, user_services: UserServices):
|
def __init__(
|
||||||
|
self, user_services: UserServices, user_role_services: UserRoleServices
|
||||||
|
):
|
||||||
self.user_services = user_services
|
self.user_services = user_services
|
||||||
|
self.user_role_services = user_role_services
|
||||||
|
|
||||||
@handler.post("/register", admin=False)
|
@handler.post("/register", admin=False)
|
||||||
async def register(self, data: UserRegIn):
|
async def register(self, data: UserRegIn):
|
||||||
if data.username.upper() in SystemUserEnum.__members__:
|
if data.username.upper() in SystemUserEnum.__members__:
|
||||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
return BaseApiOut(status=500, msg="用户名已被注册", data=None)
|
||||||
user = await self.user_services.get_user(username=data.username)
|
user = await self.user_services.get_user(username=data.username)
|
||||||
if user:
|
if user:
|
||||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
return BaseApiOut(status=500, msg="用户名已被注册", data=None)
|
||||||
role = UserRoleEnum.STUDENT.value
|
role = UserRoleEnum.STUDENT.value
|
||||||
if not (data.student_id or data.phone):
|
if not (data.student_id or data.phone):
|
||||||
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
return BaseApiOut(status=500, msg="学号或手机号至少填写一项", data=None)
|
||||||
if data.student_id:
|
if data.student_id:
|
||||||
user = await self.user_services.get_user(student_id=data.student_id)
|
user = await self.user_services.get_user(student_id=data.student_id)
|
||||||
if user:
|
if user:
|
||||||
return BaseApiOut(status=-1, msg="学号已被注册", data=None)
|
return BaseApiOut(status=500, msg="学号已被注册", data=None)
|
||||||
role = UserRoleEnum.STUDENT.value
|
role = UserRoleEnum.STUDENT.value
|
||||||
if data.phone:
|
if data.phone:
|
||||||
user = await self.user_services.get_user(phone=data.phone)
|
user = await self.user_services.get_user(phone=data.phone)
|
||||||
if user:
|
if user:
|
||||||
return BaseApiOut(status=-1, msg="手机号已被注册", data=None)
|
return BaseApiOut(status=500, msg="手机号已被注册", data=None)
|
||||||
role = UserRoleEnum.OUT.value
|
role = UserRoleEnum.OUT.value
|
||||||
# 检查通过,注册用户
|
# 检查通过,注册用户
|
||||||
try:
|
try:
|
||||||
@ -55,8 +59,12 @@ class UserRoutes(Plugin):
|
|||||||
student_id=data.student_id,
|
student_id=data.student_id,
|
||||||
phone=data.phone,
|
phone=data.phone,
|
||||||
)
|
)
|
||||||
if not await self.user_services.is_user_in_role_group(data.username, role):
|
if not await self.user_role_services.is_user_in_role_group(
|
||||||
await self.user_services.add_user_to_role_group(data.username, role)
|
data.username, role
|
||||||
|
):
|
||||||
|
await self.user_role_services.add_user_to_role_group(
|
||||||
|
data.username, role
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
@ -65,6 +73,7 @@ class UserRoutes(Plugin):
|
|||||||
# 注册成功,设置用户信息
|
# 注册成功,设置用户信息
|
||||||
token_info = UserLoginOut.model_validate(user)
|
token_info = UserLoginOut.model_validate(user)
|
||||||
token_info.access_token = await self.user_services.login_user(user)
|
token_info.access_token = await self.user_services.login_user(user)
|
||||||
|
token_info.roles = role
|
||||||
return BaseApiOut(code=0, msg="注册成功", data=token_info)
|
return BaseApiOut(code=0, msg="注册成功", data=token_info)
|
||||||
|
|
||||||
async def create_login_history(self, request: "Request"):
|
async def create_login_history(self, request: "Request"):
|
||||||
@ -83,18 +92,19 @@ class UserRoutes(Plugin):
|
|||||||
|
|
||||||
@handler.post("/login", admin=False)
|
@handler.post("/login", admin=False)
|
||||||
async def login(
|
async def login(
|
||||||
self, request: Request, response: Response, username: str, password: str
|
self,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
user: UserLoginData,
|
||||||
):
|
):
|
||||||
auth: "Auth" = request.auth
|
auth: "Auth" = request.auth
|
||||||
if request.scope.get("user"):
|
user = await auth.authenticate_user(
|
||||||
return BaseApiOut(
|
username=user.username, password=user.password
|
||||||
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
|
|
||||||
)
|
)
|
||||||
user = await auth.authenticate_user(username=username, password=password)
|
|
||||||
if not user:
|
if not user:
|
||||||
return BaseApiOut(status=-1, msg="用户名或密码错误")
|
return BaseApiOut(status=500, msg="用户名或密码错误")
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
return BaseApiOut(status=-2, msg="用户未激活")
|
return BaseApiOut(status=500, msg="用户未激活")
|
||||||
request.scope["user"] = user
|
request.scope["user"] = user
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -109,6 +119,9 @@ class UserRoutes(Plugin):
|
|||||||
token_info.access_token = await auth.backend.token_store.write_token(
|
token_info.access_token = await auth.backend.token_store.write_token(
|
||||||
request.user.dict()
|
request.user.dict()
|
||||||
)
|
)
|
||||||
|
token_info.roles = ",".join(
|
||||||
|
await self.user_role_services.get_user_roles(user.username)
|
||||||
|
)
|
||||||
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
|
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
|
||||||
return BaseApiOut(code=0, data=token_info)
|
return BaseApiOut(code=0, data=token_info)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
|
||||||
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
||||||
from fastapi_amis_admin.utils.translation import i18n as _
|
from fastapi_amis_admin.utils.translation import i18n as _
|
||||||
@ -16,11 +16,17 @@ class BaseTokenData(BaseModel):
|
|||||||
username: str
|
username: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserLoginData(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
class UserLoginOut(UserModel):
|
class UserLoginOut(UserModel):
|
||||||
"""用户登录返回信息"""
|
"""用户登录返回信息"""
|
||||||
|
|
||||||
token_type: str = "bearer"
|
token_type: str = "bearer"
|
||||||
access_token: Optional[str] = None
|
access_token: Optional[str] = None
|
||||||
|
roles: str = ""
|
||||||
password: SecretStr = Field(
|
password: SecretStr = Field(
|
||||||
title=_("Password"),
|
title=_("Password"),
|
||||||
max_length=128,
|
max_length=128,
|
||||||
|
@ -16,17 +16,6 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
def __init__(self, repo: UserRepo):
|
def __init__(self, repo: UserRepo):
|
||||||
self.repo = repo
|
self.repo = repo
|
||||||
self.user_model = UserModel
|
self.user_model = UserModel
|
||||||
self.role_model = RoleModel
|
|
||||||
self.rule_model = CasbinRule
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
for g in UserRoleEnum.__members__.keys():
|
|
||||||
key = g.lower()
|
|
||||||
if key == "admin":
|
|
||||||
continue
|
|
||||||
if await self.get_role(key=key) is None:
|
|
||||||
await self.create_role(key, f"{key} role")
|
|
||||||
print(f"Create role: {key}")
|
|
||||||
|
|
||||||
async def register_user(
|
async def register_user(
|
||||||
self,
|
self,
|
||||||
@ -53,11 +42,46 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
return await self.repo.get_user(username, student_id, phone)
|
return await self.repo.get_user(username, student_id, phone)
|
||||||
|
|
||||||
|
async def create_login_history(
|
||||||
|
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
||||||
|
):
|
||||||
|
history = LoginHistory(
|
||||||
|
user_id=user.id,
|
||||||
|
login_name=user.username,
|
||||||
|
ip=ip,
|
||||||
|
user_agent=ua,
|
||||||
|
login_status="登录成功",
|
||||||
|
forwarded_for=forwarded_for,
|
||||||
|
)
|
||||||
|
return await self.repo.create_login_history(history)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRoleServices(AsyncInitializingComponent):
|
||||||
|
__order__ = 1
|
||||||
|
|
||||||
|
def __init__(self, repo: UserRepo):
|
||||||
|
self.repo = repo
|
||||||
|
self.role_model = RoleModel
|
||||||
|
self.rule_model = CasbinRule
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
for g in UserRoleEnum.__members__.keys():
|
||||||
|
key = g.lower()
|
||||||
|
if key == "admin":
|
||||||
|
continue
|
||||||
|
if await self.get_role(key=key) is None:
|
||||||
|
await self.create_role(key, f"{key} role")
|
||||||
|
print(f"Create role: {key}")
|
||||||
|
|
||||||
async def get_role(
|
async def get_role(
|
||||||
self, rid: Optional[int] = None, key: Optional[str] = None
|
self, rid: Optional[int] = None, key: Optional[str] = None
|
||||||
) -> Optional[RoleModel]:
|
) -> Optional[RoleModel]:
|
||||||
return await self.repo.get_role(rid, key)
|
return await self.repo.get_role(rid, key)
|
||||||
|
|
||||||
|
async def get_user_roles(self, username: str) -> List[str]:
|
||||||
|
role_keys = await self.repo.AUTH.enforcer.get_roles_for_user(f"u:{username}")
|
||||||
|
return [i.replace("r:", "") for i in role_keys]
|
||||||
|
|
||||||
async def create_role(
|
async def create_role(
|
||||||
self, key: str, name: str, description: Optional[str] = None
|
self, key: str, name: str, description: Optional[str] = None
|
||||||
) -> RoleModel:
|
) -> RoleModel:
|
||||||
@ -86,16 +110,3 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
await update_subject_roles(
|
await update_subject_roles(
|
||||||
self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles
|
self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_login_history(
|
|
||||||
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
|
||||||
):
|
|
||||||
history = LoginHistory(
|
|
||||||
user_id=user.id,
|
|
||||||
login_name=user.username,
|
|
||||||
ip=ip,
|
|
||||||
user_agent=ua,
|
|
||||||
login_status="登录成功",
|
|
||||||
forwarded_for=forwarded_for,
|
|
||||||
)
|
|
||||||
return await self.repo.create_login_history(history)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user