feat: user reg
This commit is contained in:
parent
277ab78938
commit
2b7f89388b
@ -20,5 +20,5 @@ DB_DATABASE=xxx
|
||||
# Redis
|
||||
REDIS_HOST=127.0.0.1
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
REDIS_DATABASE=0
|
||||
REDIS_PASSWORD=""
|
||||
|
@ -1,13 +1,60 @@
|
||||
from fastapi_amis_admin.crud import BaseApiOut
|
||||
from persica.factory.component import AsyncInitializingComponent
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from starlette import status
|
||||
|
||||
from src.core.web_app import WebApp
|
||||
from src.services.users.repositories import UserRepo
|
||||
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
|
||||
from src.services.users.services import UserServices
|
||||
|
||||
|
||||
class UserRoutes(AsyncInitializingComponent):
|
||||
def __init__(self, app: WebApp):
|
||||
app.app.add_api_route("/users", self.test_get, methods=["GET"])
|
||||
def __init__(self, app: WebApp, user_services: UserServices):
|
||||
self.router = APIRouter(prefix="/user")
|
||||
self.router.add_api_route("/register", self.register, methods=["POST"])
|
||||
self.user_services = user_services
|
||||
app.app.include_router(self.router)
|
||||
|
||||
async def test_get(self):
|
||||
print(UserRepo.AUTH is not None)
|
||||
return {}
|
||||
async def register(self, data: UserRegIn):
|
||||
if data.username.upper() in SystemUserEnum.__members__:
|
||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||
user = await self.user_services.get_user(username=data.username)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||
user = await self.user_services.get_user(email=data.email)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="邮箱已被注册", data=None)
|
||||
role = "student"
|
||||
if not (data.student_id or data.phone):
|
||||
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
||||
if data.student_id:
|
||||
user = await self.user_services.get_user(student_id=data.student_id)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="学号已被注册", data=None)
|
||||
role = "student"
|
||||
if data.phone:
|
||||
user = await self.user_services.get_user(phone=data.phone)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="手机号已被注册", data=None)
|
||||
role = "out"
|
||||
# 检查通过,注册用户
|
||||
try:
|
||||
user = await self.user_services.register_user(
|
||||
username=data.username,
|
||||
password=data.password,
|
||||
email=data.email,
|
||||
student_id=data.student_id,
|
||||
phone=data.phone,
|
||||
)
|
||||
if not await self.user_services.is_user_in_role_group(data.username, role):
|
||||
await self.user_services.add_user_to_role_group(data.username, role)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error Execute SQL:{e}",
|
||||
) from e
|
||||
# 注册成功,设置用户信息
|
||||
token_info = UserLoginOut.model_validate(user)
|
||||
token_info.access_token = await self.user_services.login_user(user)
|
||||
return BaseApiOut(code=0, msg="注册成功", data=token_info)
|
||||
|
@ -1,17 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_amis_admin.models.fields import Field
|
||||
from fastapi_user_auth.auth.models import BaseUser, Role
|
||||
from fastapi_user_auth.auth.models import BaseUser, Role as RoleModel
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class UserModel(BaseUser, table=True):
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
class StudentIdMixin(SQLModel):
|
||||
student_id: Optional[str] = Field("", title="学号", max_length=15)
|
||||
|
||||
|
||||
class PhoneMixin(SQLModel):
|
||||
phone: Optional[str] = Field("", title="电话号码", max_length=15)
|
||||
|
||||
|
||||
class RoleModel(Role, table=True):
|
||||
class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True):
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
is_active: bool = Field(default=True, title="是否激活")
|
||||
|
@ -1,19 +1,119 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_user_auth.auth import Auth
|
||||
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
|
||||
from fastapi_user_auth.auth.models import CasbinRule
|
||||
from persica.factory.component import AsyncInitializingComponent
|
||||
from pydantic import SecretStr
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from src.core.database import Database
|
||||
from src.core.redis_db import RedisDB
|
||||
from src.services.users.models import UserModel
|
||||
from src.core.web_app import WebApp
|
||||
from src.services.users.models import UserModel, RoleModel
|
||||
|
||||
|
||||
class UserRepo(AsyncInitializingComponent):
|
||||
AUTH: Auth = None
|
||||
|
||||
def __init__(self, database: Database, redis: RedisDB):
|
||||
def __init__(self, app: WebApp, database: Database, redis: RedisDB):
|
||||
self.engine = database.engine
|
||||
self.AUTH = Auth(
|
||||
database.db, token_store=RedisTokenStore(redis.client), user_model=UserModel
|
||||
database.db,
|
||||
token_store=RedisTokenStore(redis.client),
|
||||
user_model=UserModel,
|
||||
)
|
||||
self.AUTH.backend.attach_middleware(app.app)
|
||||
self.user_model = UserModel
|
||||
self.role_model = RoleModel
|
||||
self.rule_model = CasbinRule
|
||||
|
||||
async def initialize(self):
|
||||
await self.AUTH.create_role_user("admin")
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
username: str,
|
||||
password: SecretStr,
|
||||
email: str,
|
||||
student_id: Optional[str],
|
||||
phone: Optional[str],
|
||||
):
|
||||
password = self.AUTH.pwd_context.hash(password.get_secret_value())
|
||||
values = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
"email": email,
|
||||
"student_id": student_id,
|
||||
"phone": phone,
|
||||
}
|
||||
user = self.user_model.model_validate(values)
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_user(
|
||||
self,
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
student_id: Optional[str] = None,
|
||||
phone: Optional[str] = None,
|
||||
) -> Optional[UserModel]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(self.user_model)
|
||||
if username:
|
||||
statement = statement.where(self.user_model.username == username)
|
||||
if email:
|
||||
statement = statement.where(self.user_model.email == email)
|
||||
if student_id:
|
||||
statement = statement.where(self.user_model.student_id == student_id)
|
||||
if phone:
|
||||
statement = statement.where(self.user_model.phone == phone)
|
||||
r = await session.exec(statement)
|
||||
return r.first()
|
||||
|
||||
async def get_role(
|
||||
self, rid: Optional[int] = None, key: Optional[str] = None
|
||||
) -> Optional[RoleModel]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(self.role_model)
|
||||
if rid:
|
||||
statement = statement.where(self.role_model.id == rid)
|
||||
if key:
|
||||
statement = statement.where(self.role_model.key == key)
|
||||
r = await session.exec(statement)
|
||||
return r.first()
|
||||
|
||||
async def create_role(self, role: "RoleModel") -> RoleModel:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
return role
|
||||
|
||||
async def get_role_rule(
|
||||
self,
|
||||
ptype: Optional[str] = None,
|
||||
v0: Optional[str] = None,
|
||||
v1: Optional[str] = None,
|
||||
) -> Optional[RoleModel]:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
statement = select(self.rule_model)
|
||||
if ptype:
|
||||
statement = statement.where(self.rule_model.ptype == ptype)
|
||||
if v0:
|
||||
statement = statement.where(self.rule_model.v0 == v0)
|
||||
if v1:
|
||||
statement = statement.where(self.rule_model.v1 == v1)
|
||||
r = await session.exec(statement)
|
||||
return r.first()
|
||||
|
||||
async def create_role_rule(self, rule: "CasbinRule") -> CasbinRule:
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(rule)
|
||||
await session.commit()
|
||||
await session.refresh(rule)
|
||||
return rule
|
||||
|
@ -0,0 +1,69 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
||||
from fastapi_amis_admin.utils.translation import i18n as _
|
||||
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin
|
||||
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from sqlmodel import Field
|
||||
|
||||
from .models import UserModel, StudentIdMixin, PhoneMixin
|
||||
|
||||
|
||||
class BaseTokenData(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
|
||||
|
||||
class UserLoginOut(UserModel):
|
||||
"""用户登录返回信息"""
|
||||
|
||||
token_type: str = "bearer"
|
||||
access_token: Optional[str] = None
|
||||
password: SecretStr = Field(
|
||||
title=_("Password"),
|
||||
max_length=128,
|
||||
sa_type=SecretStrType,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin):
|
||||
"""用户注册"""
|
||||
|
||||
password2: str = Field(title=_("Confirm Password"), max_length=128)
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import model_validator
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_passwords_match(self):
|
||||
if (
|
||||
self.password is not None
|
||||
and self.password.get_secret_value() != self.password2
|
||||
):
|
||||
raise ValueError("passwords do not match!")
|
||||
return self
|
||||
|
||||
else:
|
||||
from pydantic import validator
|
||||
|
||||
@validator("password2")
|
||||
def passwords_match_(cls, v, values, **kwargs):
|
||||
if "password" in values and v != values["password"]:
|
||||
raise ValueError("passwords do not match!")
|
||||
return v
|
||||
|
||||
|
||||
# 默认保留的用户
|
||||
class SystemUserEnum(str, Enum):
|
||||
ROOT = "root"
|
||||
ADMIN = "admin"
|
||||
GUEST = "guest"
|
||||
|
||||
|
||||
class UserRoleEnum(str, Enum):
|
||||
ADMIN = "admin" # 管理员
|
||||
STUDENT = "student" # 教职工
|
||||
OUT = "out" # 校外人员
|
77
src/services/users/services.py
Normal file
77
src/services/users/services.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_user_auth.auth.models import CasbinRule
|
||||
from persica.factory.component import AsyncInitializingComponent
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .models import UserModel, RoleModel
|
||||
from src.services.users.repositories import UserRepo
|
||||
from .schemas import UserRoleEnum
|
||||
|
||||
|
||||
class UserServices(AsyncInitializingComponent):
|
||||
def __init__(self, repo: UserRepo):
|
||||
self.repo = repo
|
||||
self.user_model = UserModel
|
||||
self.role_model = RoleModel
|
||||
self.rule_model = CasbinRule
|
||||
|
||||
async def initialize(self):
|
||||
for g in UserRoleEnum.__members__.keys():
|
||||
key = g.lower()
|
||||
if await self.get_role(key=key) is None:
|
||||
await self.create_role(key, f"{key} role")
|
||||
print(f"Create role: {key}")
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
username: str,
|
||||
password: SecretStr,
|
||||
email: str,
|
||||
student_id: Optional[str],
|
||||
phone: Optional[str],
|
||||
):
|
||||
return await self.repo.register_user(
|
||||
username, password, email, student_id, phone
|
||||
)
|
||||
|
||||
async def login_user(self, user: "UserModel") -> str:
|
||||
return await self.repo.AUTH.backend.token_store.write_token(user.model_dump())
|
||||
|
||||
async def get_user(
|
||||
self,
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
student_id: Optional[str] = None,
|
||||
phone: Optional[str] = None,
|
||||
) -> Optional[UserModel]:
|
||||
return await self.repo.get_user(username, email, student_id, phone)
|
||||
|
||||
async def get_role(
|
||||
self, rid: Optional[int] = None, key: Optional[str] = None
|
||||
) -> Optional[RoleModel]:
|
||||
return await self.repo.get_role(rid, key)
|
||||
|
||||
async def create_role(
|
||||
self, key: str, name: str, description: Optional[str] = None
|
||||
) -> RoleModel:
|
||||
role = self.role_model(key=key, name=name, description=description)
|
||||
return await self.repo.create_role(role)
|
||||
|
||||
async def get_role_rule(
|
||||
self,
|
||||
ptype: Optional[str] = None,
|
||||
v0: Optional[str] = None,
|
||||
v1: Optional[str] = None,
|
||||
) -> Optional[RoleModel]:
|
||||
return await self.repo.get_role_rule(ptype, v0, v1)
|
||||
|
||||
async def is_user_in_role_group(self, username: str, role_key: str) -> bool:
|
||||
return (
|
||||
await self.repo.get_role_rule("g", f"u:{username}", f"r:{role_key}")
|
||||
is not None
|
||||
)
|
||||
|
||||
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
|
||||
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
|
||||
return await self.repo.create_role_rule(rule)
|
Loading…
Reference in New Issue
Block a user