feat: user reg

This commit is contained in:
xtaodada 2024-11-04 17:35:28 +08:00
parent 277ab78938
commit 2b7f89388b
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
6 changed files with 310 additions and 17 deletions

View File

@ -20,5 +20,5 @@ DB_DATABASE=xxx
# Redis
REDIS_HOST=127.0.0.1
REDIS_PORT=6379
REDIS_DB=0
REDIS_DATABASE=0
REDIS_PASSWORD=""

View File

@ -1,13 +1,60 @@
from fastapi_amis_admin.crud import BaseApiOut
from persica.factory.component import AsyncInitializingComponent
from fastapi import APIRouter, HTTPException
from starlette import status
from src.core.web_app import WebApp
from src.services.users.repositories import UserRepo
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
from src.services.users.services import UserServices
class UserRoutes(AsyncInitializingComponent):
def __init__(self, app: WebApp):
app.app.add_api_route("/users", self.test_get, methods=["GET"])
def __init__(self, app: WebApp, user_services: UserServices):
self.router = APIRouter(prefix="/user")
self.router.add_api_route("/register", self.register, methods=["POST"])
self.user_services = user_services
app.app.include_router(self.router)
async def test_get(self):
print(UserRepo.AUTH is not None)
return {}
async def register(self, data: UserRegIn):
if data.username.upper() in SystemUserEnum.__members__:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
user = await self.user_services.get_user(username=data.username)
if user:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
user = await self.user_services.get_user(email=data.email)
if user:
return BaseApiOut(status=-1, msg="邮箱已被注册", data=None)
role = "student"
if not (data.student_id or data.phone):
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
if data.student_id:
user = await self.user_services.get_user(student_id=data.student_id)
if user:
return BaseApiOut(status=-1, msg="学号已被注册", data=None)
role = "student"
if data.phone:
user = await self.user_services.get_user(phone=data.phone)
if user:
return BaseApiOut(status=-1, msg="手机号已被注册", data=None)
role = "out"
# 检查通过,注册用户
try:
user = await self.user_services.register_user(
username=data.username,
password=data.password,
email=data.email,
student_id=data.student_id,
phone=data.phone,
)
if not await self.user_services.is_user_in_role_group(data.username, role):
await self.user_services.add_user_to_role_group(data.username, role)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error Execute SQL{e}",
) from e
# 注册成功,设置用户信息
token_info = UserLoginOut.model_validate(user)
token_info.access_token = await self.user_services.login_user(user)
return BaseApiOut(code=0, msg="注册成功", data=token_info)

View File

@ -1,17 +1,17 @@
from typing import Optional
from fastapi_amis_admin.models.fields import Field
from fastapi_user_auth.auth.models import BaseUser, Role
from fastapi_user_auth.auth.models import BaseUser, Role as RoleModel
from sqlmodel import SQLModel
class UserModel(BaseUser, table=True):
__table_args__ = {"extend_existing": True}
class StudentIdMixin(SQLModel):
student_id: Optional[str] = Field("", title="学号", max_length=15)
class PhoneMixin(SQLModel):
phone: Optional[str] = Field("", title="电话号码", max_length=15)
class RoleModel(Role, table=True):
class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True):
__table_args__ = {"extend_existing": True}
is_active: bool = Field(default=True, title="是否激活")

View File

@ -1,19 +1,119 @@
from typing import Optional
from fastapi_user_auth.auth import Auth
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
from fastapi_user_auth.auth.models import CasbinRule
from persica.factory.component import AsyncInitializingComponent
from pydantic import SecretStr
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from src.core.database import Database
from src.core.redis_db import RedisDB
from src.services.users.models import UserModel
from src.core.web_app import WebApp
from src.services.users.models import UserModel, RoleModel
class UserRepo(AsyncInitializingComponent):
AUTH: Auth = None
def __init__(self, database: Database, redis: RedisDB):
def __init__(self, app: WebApp, database: Database, redis: RedisDB):
self.engine = database.engine
self.AUTH = Auth(
database.db, token_store=RedisTokenStore(redis.client), user_model=UserModel
database.db,
token_store=RedisTokenStore(redis.client),
user_model=UserModel,
)
self.AUTH.backend.attach_middleware(app.app)
self.user_model = UserModel
self.role_model = RoleModel
self.rule_model = CasbinRule
async def initialize(self):
await self.AUTH.create_role_user("admin")
async def register_user(
self,
username: str,
password: SecretStr,
email: str,
student_id: Optional[str],
phone: Optional[str],
):
password = self.AUTH.pwd_context.hash(password.get_secret_value())
values = {
"username": username,
"password": password,
"email": email,
"student_id": student_id,
"phone": phone,
}
user = self.user_model.model_validate(values)
async with AsyncSession(self.engine) as session:
session.add(user)
await session.commit()
await session.refresh(user)
return user
async def get_user(
self,
username: Optional[str] = None,
email: Optional[str] = None,
student_id: Optional[str] = None,
phone: Optional[str] = None,
) -> Optional[UserModel]:
async with AsyncSession(self.engine) as session:
statement = select(self.user_model)
if username:
statement = statement.where(self.user_model.username == username)
if email:
statement = statement.where(self.user_model.email == email)
if student_id:
statement = statement.where(self.user_model.student_id == student_id)
if phone:
statement = statement.where(self.user_model.phone == phone)
r = await session.exec(statement)
return r.first()
async def get_role(
self, rid: Optional[int] = None, key: Optional[str] = None
) -> Optional[RoleModel]:
async with AsyncSession(self.engine) as session:
statement = select(self.role_model)
if rid:
statement = statement.where(self.role_model.id == rid)
if key:
statement = statement.where(self.role_model.key == key)
r = await session.exec(statement)
return r.first()
async def create_role(self, role: "RoleModel") -> RoleModel:
async with AsyncSession(self.engine) as session:
session.add(role)
await session.commit()
await session.refresh(role)
return role
async def get_role_rule(
self,
ptype: Optional[str] = None,
v0: Optional[str] = None,
v1: Optional[str] = None,
) -> Optional[RoleModel]:
async with AsyncSession(self.engine) as session:
statement = select(self.rule_model)
if ptype:
statement = statement.where(self.rule_model.ptype == ptype)
if v0:
statement = statement.where(self.rule_model.v0 == v0)
if v1:
statement = statement.where(self.rule_model.v1 == v1)
r = await session.exec(statement)
return r.first()
async def create_role_rule(self, rule: "CasbinRule") -> CasbinRule:
async with AsyncSession(self.engine) as session:
session.add(rule)
await session.commit()
await session.refresh(rule)
return rule

View File

@ -0,0 +1,69 @@
from enum import Enum
from typing import Optional
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
from fastapi_amis_admin.utils.translation import i18n as _
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin
from fastapi_user_auth.utils.sqltypes import SecretStrType
from pydantic import BaseModel, SecretStr
from sqlmodel import Field
from .models import UserModel, StudentIdMixin, PhoneMixin
class BaseTokenData(BaseModel):
id: int
username: str
class UserLoginOut(UserModel):
"""用户登录返回信息"""
token_type: str = "bearer"
access_token: Optional[str] = None
password: SecretStr = Field(
title=_("Password"),
max_length=128,
sa_type=SecretStrType,
nullable=False,
)
class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin):
"""用户注册"""
password2: str = Field(title=_("Confirm Password"), max_length=128)
if PYDANTIC_V2:
from pydantic import model_validator
@model_validator(mode="after")
def check_passwords_match(self):
if (
self.password is not None
and self.password.get_secret_value() != self.password2
):
raise ValueError("passwords do not match!")
return self
else:
from pydantic import validator
@validator("password2")
def passwords_match_(cls, v, values, **kwargs):
if "password" in values and v != values["password"]:
raise ValueError("passwords do not match!")
return v
# 默认保留的用户
class SystemUserEnum(str, Enum):
ROOT = "root"
ADMIN = "admin"
GUEST = "guest"
class UserRoleEnum(str, Enum):
ADMIN = "admin" # 管理员
STUDENT = "student" # 教职工
OUT = "out" # 校外人员

View File

@ -0,0 +1,77 @@
from typing import Optional
from fastapi_user_auth.auth.models import CasbinRule
from persica.factory.component import AsyncInitializingComponent
from pydantic import SecretStr
from .models import UserModel, RoleModel
from src.services.users.repositories import UserRepo
from .schemas import UserRoleEnum
class UserServices(AsyncInitializingComponent):
def __init__(self, repo: UserRepo):
self.repo = repo
self.user_model = UserModel
self.role_model = RoleModel
self.rule_model = CasbinRule
async def initialize(self):
for g in UserRoleEnum.__members__.keys():
key = g.lower()
if await self.get_role(key=key) is None:
await self.create_role(key, f"{key} role")
print(f"Create role: {key}")
async def register_user(
self,
username: str,
password: SecretStr,
email: str,
student_id: Optional[str],
phone: Optional[str],
):
return await self.repo.register_user(
username, password, email, student_id, phone
)
async def login_user(self, user: "UserModel") -> str:
return await self.repo.AUTH.backend.token_store.write_token(user.model_dump())
async def get_user(
self,
username: Optional[str] = None,
email: Optional[str] = None,
student_id: Optional[str] = None,
phone: Optional[str] = None,
) -> Optional[UserModel]:
return await self.repo.get_user(username, email, student_id, phone)
async def get_role(
self, rid: Optional[int] = None, key: Optional[str] = None
) -> Optional[RoleModel]:
return await self.repo.get_role(rid, key)
async def create_role(
self, key: str, name: str, description: Optional[str] = None
) -> RoleModel:
role = self.role_model(key=key, name=name, description=description)
return await self.repo.create_role(role)
async def get_role_rule(
self,
ptype: Optional[str] = None,
v0: Optional[str] = None,
v1: Optional[str] = None,
) -> Optional[RoleModel]:
return await self.repo.get_role_rule(ptype, v0, v1)
async def is_user_in_role_group(self, username: str, role_key: str) -> bool:
return (
await self.repo.get_role_rule("g", f"u:{username}", f"r:{role_key}")
is not None
)
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
return await self.repo.create_role_rule(rule)