diff --git a/alembic/versions/089138f9c051_users.py b/alembic/versions/3785e9a2a0c0_users.py similarity index 98% rename from alembic/versions/089138f9c051_users.py rename to alembic/versions/3785e9a2a0c0_users.py index d5fe20e..ef92fd1 100644 --- a/alembic/versions/089138f9c051_users.py +++ b/alembic/versions/3785e9a2a0c0_users.py @@ -1,8 +1,8 @@ """users -Revision ID: 089138f9c051 +Revision ID: 3785e9a2a0c0 Revises: -Create Date: 2024-11-04 15:31:52.096235 +Create Date: 2024-11-04 19:12:43.374773 """ @@ -12,7 +12,7 @@ import sqlmodel from fastapi_user_auth.utils.sqltypes import SecretStrType # revision identifiers, used by Alembic. -revision = "089138f9c051" +revision = "3785e9a2a0c0" down_revision = None branch_labels = None depends_on = None @@ -162,12 +162,12 @@ def upgrade() -> None: sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True, ), + sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True), sa.Column( "student_id", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True, ), - sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True), sa.PrimaryKeyConstraint("id"), ) op.create_index( diff --git a/src/core/web_app.py b/src/core/web_app.py index d680e23..28543f7 100644 --- a/src/core/web_app.py +++ b/src/core/web_app.py @@ -10,6 +10,8 @@ from src.env import config class WebApp(AsyncInitializingComponent): + __order__ = 3 + def __init__(self): dependencies = [] self.app = FastAPI(dependencies=dependencies) diff --git a/src/route/users.py b/src/route/users.py index 98fbffe..343ba57 100644 --- a/src/route/users.py +++ b/src/route/users.py @@ -1,20 +1,39 @@ +from typing import TYPE_CHECKING + from fastapi_amis_admin.crud import BaseApiOut from persica.factory.component import AsyncInitializingComponent -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Depends from starlette import status +from starlette.requests import Request +from starlette.responses import Response from src.core.web_app import WebApp from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut from src.services.users.services import UserServices +if TYPE_CHECKING: + from fastapi_user_auth.auth import Auth + class UserRoutes(AsyncInitializingComponent): + __order__ = 2 + def __init__(self, app: WebApp, user_services: UserServices): self.router = APIRouter(prefix="/user") self.router.add_api_route("/register", self.register, methods=["POST"]) + self.router.add_api_route("/login", self.login, methods=["POST"]) self.user_services = user_services - app.app.include_router(self.router) + self.app = app.app + + async def initialize(self): + self.router.add_api_route( + "/need_login", + self.need_login, + methods=["GET"], + dependencies=[Depends(self.user_services.repo.AUTH.requires("admin")())], + ) + self.app.include_router(self.router) async def register(self, data: UserRegIn): if data.username.upper() in SystemUserEnum.__members__: @@ -22,9 +41,6 @@ class UserRoutes(AsyncInitializingComponent): user = await self.user_services.get_user(username=data.username) if user: return BaseApiOut(status=-1, msg="用户名已被注册", data=None) - user = await self.user_services.get_user(email=data.email) - if user: - return BaseApiOut(status=-1, msg="邮箱已被注册", data=None) role = "student" if not (data.student_id or data.phone): return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None) @@ -43,7 +59,6 @@ class UserRoutes(AsyncInitializingComponent): user = await self.user_services.register_user( username=data.username, password=data.password, - email=data.email, student_id=data.student_id, phone=data.phone, ) @@ -58,3 +73,51 @@ class UserRoutes(AsyncInitializingComponent): token_info = UserLoginOut.model_validate(user) token_info.access_token = await self.user_services.login_user(user) return BaseApiOut(code=0, msg="注册成功", data=token_info) + + async def create_login_history(self, request: "Request"): + # 保存登录记录 + ip = request.client.host # 获取真实ip + # 获取代理ip + ips = [ + request.headers.get(key, "").strip() + for key in ["x-forwarded-for", "x-real-ip", "x-client-ip", "remote-host"] + ] + forwarded_for = ",".join([i for i in set(ips) if i and i != ip]) + ua = request.headers.get("user-agent", "") + return await self.user_services.create_login_history( + request.scope.get("user"), ip, ua, forwarded_for + ) + + async def login( + self, request: Request, response: Response, username: str, password: str + ): + auth: "Auth" = request.auth + if request.scope.get("user"): + return BaseApiOut( + code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user) + ) + user = await auth.authenticate_user(username=username, password=password) + if not user: + return BaseApiOut(status=-1, msg="用户名或密码错误") + if not user.is_active: + return BaseApiOut(status=-2, msg="用户未激活") + request.scope["user"] = user + + try: + await self.create_login_history(request) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error Execute SQL:{e}", + ) from e + + token_info = UserLoginOut.model_validate(request.user) + token_info.access_token = await auth.backend.token_store.write_token( + request.user.dict() + ) + response.set_cookie("Authorization", f"bearer {token_info.access_token}") + return BaseApiOut(code=0, data=token_info) + + @staticmethod + async def need_login(): + return {} diff --git a/src/services/users/models.py b/src/services/users/models.py index 9588f4e..8a9d565 100644 --- a/src/services/users/models.py +++ b/src/services/users/models.py @@ -15,3 +15,5 @@ class PhoneMixin(SQLModel): class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True): __table_args__ = {"extend_existing": True} + + email: Optional[str] = None diff --git a/src/services/users/repositories.py b/src/services/users/repositories.py index 9f6c542..7b656a5 100644 --- a/src/services/users/repositories.py +++ b/src/services/users/repositories.py @@ -2,7 +2,7 @@ from typing import Optional from fastapi_user_auth.auth import Auth from fastapi_user_auth.auth.backends.redis import RedisTokenStore -from fastapi_user_auth.auth.models import CasbinRule +from fastapi_user_auth.auth.models import CasbinRule, LoginHistory from persica.factory.component import AsyncInitializingComponent from pydantic import SecretStr from sqlmodel import select @@ -15,28 +15,31 @@ from src.services.users.models import UserModel, RoleModel class UserRepo(AsyncInitializingComponent): + __order__ = 1 AUTH: Auth = None def __init__(self, app: WebApp, database: Database, redis: RedisDB): self.engine = database.engine - self.AUTH = Auth( - database.db, - token_store=RedisTokenStore(redis.client), - user_model=UserModel, - ) - self.AUTH.backend.attach_middleware(app.app) + self.database = database + self.redis = redis + self.app = app self.user_model = UserModel self.role_model = RoleModel self.rule_model = CasbinRule async def initialize(self): + self.AUTH = Auth( + self.database.db, + token_store=RedisTokenStore(self.redis.client), + user_model=UserModel, + ) + self.AUTH.backend.attach_middleware(self.app.app) await self.AUTH.create_role_user("admin") async def register_user( self, username: str, password: SecretStr, - email: str, student_id: Optional[str], phone: Optional[str], ): @@ -44,7 +47,6 @@ class UserRepo(AsyncInitializingComponent): values = { "username": username, "password": password, - "email": email, "student_id": student_id, "phone": phone, } @@ -58,7 +60,6 @@ class UserRepo(AsyncInitializingComponent): async def get_user( self, username: Optional[str] = None, - email: Optional[str] = None, student_id: Optional[str] = None, phone: Optional[str] = None, ) -> Optional[UserModel]: @@ -66,8 +67,6 @@ class UserRepo(AsyncInitializingComponent): statement = select(self.user_model) if username: statement = statement.where(self.user_model.username == username) - if email: - statement = statement.where(self.user_model.email == email) if student_id: statement = statement.where(self.user_model.student_id == student_id) if phone: @@ -117,3 +116,10 @@ class UserRepo(AsyncInitializingComponent): await session.commit() await session.refresh(rule) return rule + + async def create_login_history(self, login_history: "LoginHistory"): + async with AsyncSession(self.engine) as session: + session.add(login_history) + await session.commit() + await session.refresh(login_history) + return login_history diff --git a/src/services/users/schemas.py b/src/services/users/schemas.py index 758849c..b301a4e 100644 --- a/src/services/users/schemas.py +++ b/src/services/users/schemas.py @@ -3,7 +3,7 @@ from typing import Optional from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2 from fastapi_amis_admin.utils.translation import i18n as _ -from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin +from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin from fastapi_user_auth.utils.sqltypes import SecretStrType from pydantic import BaseModel, SecretStr from sqlmodel import Field @@ -29,7 +29,7 @@ class UserLoginOut(UserModel): ) -class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin): +class UserRegIn(UsernameMixin, PasswordMixin, StudentIdMixin, PhoneMixin): """用户注册""" password2: str = Field(title=_("Confirm Password"), max_length=128) diff --git a/src/services/users/services.py b/src/services/users/services.py index 5f57f9a..dc9fa48 100644 --- a/src/services/users/services.py +++ b/src/services/users/services.py @@ -1,6 +1,6 @@ from typing import Optional -from fastapi_user_auth.auth.models import CasbinRule +from fastapi_user_auth.auth.models import CasbinRule, LoginHistory from persica.factory.component import AsyncInitializingComponent from pydantic import SecretStr @@ -10,6 +10,8 @@ from .schemas import UserRoleEnum class UserServices(AsyncInitializingComponent): + __order__ = 1 + def __init__(self, repo: UserRepo): self.repo = repo self.user_model = UserModel @@ -19,6 +21,8 @@ class UserServices(AsyncInitializingComponent): async def initialize(self): for g in UserRoleEnum.__members__.keys(): key = g.lower() + if key == "admin": + continue if await self.get_role(key=key) is None: await self.create_role(key, f"{key} role") print(f"Create role: {key}") @@ -27,12 +31,14 @@ class UserServices(AsyncInitializingComponent): self, username: str, password: SecretStr, - email: str, student_id: Optional[str], phone: Optional[str], ): return await self.repo.register_user( - username, password, email, student_id, phone + username, + password, + student_id, + phone, ) async def login_user(self, user: "UserModel") -> str: @@ -41,11 +47,10 @@ class UserServices(AsyncInitializingComponent): async def get_user( self, username: Optional[str] = None, - email: Optional[str] = None, student_id: Optional[str] = None, phone: Optional[str] = None, ) -> Optional[UserModel]: - return await self.repo.get_user(username, email, student_id, phone) + return await self.repo.get_user(username, student_id, phone) async def get_role( self, rid: Optional[int] = None, key: Optional[str] = None @@ -75,3 +80,16 @@ class UserServices(AsyncInitializingComponent): async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule: rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}") return await self.repo.create_role_rule(rule) + + async def create_login_history( + self, user: "UserModel", ip: str, ua: str, forwarded_for: str + ): + history = LoginHistory( + user_id=user.id, + login_name=user.username, + ip=ip, + user_agent=ua, + login_status="登录成功", + forwarded_for=forwarded_for, + ) + return await self.repo.create_login_history(history)