diff --git a/.env.example b/.env.example index 24b2545..a05bf20 100644 --- a/.env.example +++ b/.env.example @@ -20,5 +20,5 @@ DB_DATABASE=xxx # Redis REDIS_HOST=127.0.0.1 REDIS_PORT=6379 -REDIS_DB=0 +REDIS_DATABASE=0 REDIS_PASSWORD="" diff --git a/src/route/users.py b/src/route/users.py index 89401ca..98fbffe 100644 --- a/src/route/users.py +++ b/src/route/users.py @@ -1,13 +1,60 @@ +from fastapi_amis_admin.crud import BaseApiOut from persica.factory.component import AsyncInitializingComponent +from fastapi import APIRouter, HTTPException +from starlette import status + from src.core.web_app import WebApp -from src.services.users.repositories import UserRepo +from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut +from src.services.users.services import UserServices class UserRoutes(AsyncInitializingComponent): - def __init__(self, app: WebApp): - app.app.add_api_route("/users", self.test_get, methods=["GET"]) + def __init__(self, app: WebApp, user_services: UserServices): + self.router = APIRouter(prefix="/user") + self.router.add_api_route("/register", self.register, methods=["POST"]) + self.user_services = user_services + app.app.include_router(self.router) - async def test_get(self): - print(UserRepo.AUTH is not None) - return {} + async def register(self, data: UserRegIn): + if data.username.upper() in SystemUserEnum.__members__: + return BaseApiOut(status=-1, msg="用户名已被注册", data=None) + user = await self.user_services.get_user(username=data.username) + if user: + return BaseApiOut(status=-1, msg="用户名已被注册", data=None) + user = await self.user_services.get_user(email=data.email) + if user: + return BaseApiOut(status=-1, msg="邮箱已被注册", data=None) + role = "student" + if not (data.student_id or data.phone): + return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None) + if data.student_id: + user = await self.user_services.get_user(student_id=data.student_id) + if user: + return BaseApiOut(status=-1, msg="学号已被注册", data=None) + role = "student" + if data.phone: + user = await self.user_services.get_user(phone=data.phone) + if user: + return BaseApiOut(status=-1, msg="手机号已被注册", data=None) + role = "out" + # 检查通过,注册用户 + try: + user = await self.user_services.register_user( + username=data.username, + password=data.password, + email=data.email, + student_id=data.student_id, + phone=data.phone, + ) + if not await self.user_services.is_user_in_role_group(data.username, role): + await self.user_services.add_user_to_role_group(data.username, role) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error Execute SQL:{e}", + ) from e + # 注册成功,设置用户信息 + token_info = UserLoginOut.model_validate(user) + token_info.access_token = await self.user_services.login_user(user) + return BaseApiOut(code=0, msg="注册成功", data=token_info) diff --git a/src/services/users/models.py b/src/services/users/models.py index 8499f8d..9588f4e 100644 --- a/src/services/users/models.py +++ b/src/services/users/models.py @@ -1,17 +1,17 @@ from typing import Optional from fastapi_amis_admin.models.fields import Field -from fastapi_user_auth.auth.models import BaseUser, Role +from fastapi_user_auth.auth.models import BaseUser, Role as RoleModel +from sqlmodel import SQLModel -class UserModel(BaseUser, table=True): - __table_args__ = {"extend_existing": True} - +class StudentIdMixin(SQLModel): student_id: Optional[str] = Field("", title="学号", max_length=15) + + +class PhoneMixin(SQLModel): phone: Optional[str] = Field("", title="电话号码", max_length=15) -class RoleModel(Role, table=True): +class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True): __table_args__ = {"extend_existing": True} - - is_active: bool = Field(default=True, title="是否激活") diff --git a/src/services/users/repositories.py b/src/services/users/repositories.py index 9ccb80c..9f6c542 100644 --- a/src/services/users/repositories.py +++ b/src/services/users/repositories.py @@ -1,19 +1,119 @@ +from typing import Optional + from fastapi_user_auth.auth import Auth from fastapi_user_auth.auth.backends.redis import RedisTokenStore +from fastapi_user_auth.auth.models import CasbinRule from persica.factory.component import AsyncInitializingComponent +from pydantic import SecretStr +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession from src.core.database import Database from src.core.redis_db import RedisDB -from src.services.users.models import UserModel +from src.core.web_app import WebApp +from src.services.users.models import UserModel, RoleModel class UserRepo(AsyncInitializingComponent): AUTH: Auth = None - def __init__(self, database: Database, redis: RedisDB): + def __init__(self, app: WebApp, database: Database, redis: RedisDB): + self.engine = database.engine self.AUTH = Auth( - database.db, token_store=RedisTokenStore(redis.client), user_model=UserModel + database.db, + token_store=RedisTokenStore(redis.client), + user_model=UserModel, ) + self.AUTH.backend.attach_middleware(app.app) + self.user_model = UserModel + self.role_model = RoleModel + self.rule_model = CasbinRule async def initialize(self): await self.AUTH.create_role_user("admin") + + async def register_user( + self, + username: str, + password: SecretStr, + email: str, + student_id: Optional[str], + phone: Optional[str], + ): + password = self.AUTH.pwd_context.hash(password.get_secret_value()) + values = { + "username": username, + "password": password, + "email": email, + "student_id": student_id, + "phone": phone, + } + user = self.user_model.model_validate(values) + async with AsyncSession(self.engine) as session: + session.add(user) + await session.commit() + await session.refresh(user) + return user + + async def get_user( + self, + username: Optional[str] = None, + email: Optional[str] = None, + student_id: Optional[str] = None, + phone: Optional[str] = None, + ) -> Optional[UserModel]: + async with AsyncSession(self.engine) as session: + statement = select(self.user_model) + if username: + statement = statement.where(self.user_model.username == username) + if email: + statement = statement.where(self.user_model.email == email) + if student_id: + statement = statement.where(self.user_model.student_id == student_id) + if phone: + statement = statement.where(self.user_model.phone == phone) + r = await session.exec(statement) + return r.first() + + async def get_role( + self, rid: Optional[int] = None, key: Optional[str] = None + ) -> Optional[RoleModel]: + async with AsyncSession(self.engine) as session: + statement = select(self.role_model) + if rid: + statement = statement.where(self.role_model.id == rid) + if key: + statement = statement.where(self.role_model.key == key) + r = await session.exec(statement) + return r.first() + + async def create_role(self, role: "RoleModel") -> RoleModel: + async with AsyncSession(self.engine) as session: + session.add(role) + await session.commit() + await session.refresh(role) + return role + + async def get_role_rule( + self, + ptype: Optional[str] = None, + v0: Optional[str] = None, + v1: Optional[str] = None, + ) -> Optional[RoleModel]: + async with AsyncSession(self.engine) as session: + statement = select(self.rule_model) + if ptype: + statement = statement.where(self.rule_model.ptype == ptype) + if v0: + statement = statement.where(self.rule_model.v0 == v0) + if v1: + statement = statement.where(self.rule_model.v1 == v1) + r = await session.exec(statement) + return r.first() + + async def create_role_rule(self, rule: "CasbinRule") -> CasbinRule: + async with AsyncSession(self.engine) as session: + session.add(rule) + await session.commit() + await session.refresh(rule) + return rule diff --git a/src/services/users/schemas.py b/src/services/users/schemas.py index e69de29..758849c 100644 --- a/src/services/users/schemas.py +++ b/src/services/users/schemas.py @@ -0,0 +1,69 @@ +from enum import Enum +from typing import Optional + +from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2 +from fastapi_amis_admin.utils.translation import i18n as _ +from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin +from fastapi_user_auth.utils.sqltypes import SecretStrType +from pydantic import BaseModel, SecretStr +from sqlmodel import Field + +from .models import UserModel, StudentIdMixin, PhoneMixin + + +class BaseTokenData(BaseModel): + id: int + username: str + + +class UserLoginOut(UserModel): + """用户登录返回信息""" + + token_type: str = "bearer" + access_token: Optional[str] = None + password: SecretStr = Field( + title=_("Password"), + max_length=128, + sa_type=SecretStrType, + nullable=False, + ) + + +class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin): + """用户注册""" + + password2: str = Field(title=_("Confirm Password"), max_length=128) + + if PYDANTIC_V2: + from pydantic import model_validator + + @model_validator(mode="after") + def check_passwords_match(self): + if ( + self.password is not None + and self.password.get_secret_value() != self.password2 + ): + raise ValueError("passwords do not match!") + return self + + else: + from pydantic import validator + + @validator("password2") + def passwords_match_(cls, v, values, **kwargs): + if "password" in values and v != values["password"]: + raise ValueError("passwords do not match!") + return v + + +# 默认保留的用户 +class SystemUserEnum(str, Enum): + ROOT = "root" + ADMIN = "admin" + GUEST = "guest" + + +class UserRoleEnum(str, Enum): + ADMIN = "admin" # 管理员 + STUDENT = "student" # 教职工 + OUT = "out" # 校外人员 diff --git a/src/services/users/services.py b/src/services/users/services.py new file mode 100644 index 0000000..5f57f9a --- /dev/null +++ b/src/services/users/services.py @@ -0,0 +1,77 @@ +from typing import Optional + +from fastapi_user_auth.auth.models import CasbinRule +from persica.factory.component import AsyncInitializingComponent +from pydantic import SecretStr + +from .models import UserModel, RoleModel +from src.services.users.repositories import UserRepo +from .schemas import UserRoleEnum + + +class UserServices(AsyncInitializingComponent): + def __init__(self, repo: UserRepo): + self.repo = repo + self.user_model = UserModel + self.role_model = RoleModel + self.rule_model = CasbinRule + + async def initialize(self): + for g in UserRoleEnum.__members__.keys(): + key = g.lower() + if await self.get_role(key=key) is None: + await self.create_role(key, f"{key} role") + print(f"Create role: {key}") + + async def register_user( + self, + username: str, + password: SecretStr, + email: str, + student_id: Optional[str], + phone: Optional[str], + ): + return await self.repo.register_user( + username, password, email, student_id, phone + ) + + async def login_user(self, user: "UserModel") -> str: + return await self.repo.AUTH.backend.token_store.write_token(user.model_dump()) + + async def get_user( + self, + username: Optional[str] = None, + email: Optional[str] = None, + student_id: Optional[str] = None, + phone: Optional[str] = None, + ) -> Optional[UserModel]: + return await self.repo.get_user(username, email, student_id, phone) + + async def get_role( + self, rid: Optional[int] = None, key: Optional[str] = None + ) -> Optional[RoleModel]: + return await self.repo.get_role(rid, key) + + async def create_role( + self, key: str, name: str, description: Optional[str] = None + ) -> RoleModel: + role = self.role_model(key=key, name=name, description=description) + return await self.repo.create_role(role) + + async def get_role_rule( + self, + ptype: Optional[str] = None, + v0: Optional[str] = None, + v1: Optional[str] = None, + ) -> Optional[RoleModel]: + return await self.repo.get_role_rule(ptype, v0, v1) + + async def is_user_in_role_group(self, username: str, role_key: str) -> bool: + return ( + await self.repo.get_role_rule("g", f"u:{username}", f"r:{role_key}") + is not None + ) + + async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule: + rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}") + return await self.repo.create_role_rule(rule)