feat: login

This commit is contained in:
xtaodada 2024-11-04 19:41:15 +08:00
parent 2b7f89388b
commit a5e58557e3
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
7 changed files with 120 additions and 29 deletions

View File

@ -1,8 +1,8 @@
"""users """users
Revision ID: 089138f9c051 Revision ID: 3785e9a2a0c0
Revises: Revises:
Create Date: 2024-11-04 15:31:52.096235 Create Date: 2024-11-04 19:12:43.374773
""" """
@ -12,7 +12,7 @@ import sqlmodel
from fastapi_user_auth.utils.sqltypes import SecretStrType from fastapi_user_auth.utils.sqltypes import SecretStrType
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "089138f9c051" revision = "3785e9a2a0c0"
down_revision = None down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -162,12 +162,12 @@ def upgrade() -> None:
sqlmodel.sql.sqltypes.AutoString(length=255), sqlmodel.sql.sqltypes.AutoString(length=255),
nullable=True, nullable=True,
), ),
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
sa.Column( sa.Column(
"student_id", "student_id",
sqlmodel.sql.sqltypes.AutoString(length=15), sqlmodel.sql.sqltypes.AutoString(length=15),
nullable=True, nullable=True,
), ),
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint("id"),
) )
op.create_index( op.create_index(

View File

@ -10,6 +10,8 @@ from src.env import config
class WebApp(AsyncInitializingComponent): class WebApp(AsyncInitializingComponent):
__order__ = 3
def __init__(self): def __init__(self):
dependencies = [] dependencies = []
self.app = FastAPI(dependencies=dependencies) self.app = FastAPI(dependencies=dependencies)

View File

@ -1,20 +1,39 @@
from typing import TYPE_CHECKING
from fastapi_amis_admin.crud import BaseApiOut from fastapi_amis_admin.crud import BaseApiOut
from persica.factory.component import AsyncInitializingComponent from persica.factory.component import AsyncInitializingComponent
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Depends
from starlette import status from starlette import status
from starlette.requests import Request
from starlette.responses import Response
from src.core.web_app import WebApp from src.core.web_app import WebApp
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
from src.services.users.services import UserServices from src.services.users.services import UserServices
if TYPE_CHECKING:
from fastapi_user_auth.auth import Auth
class UserRoutes(AsyncInitializingComponent): class UserRoutes(AsyncInitializingComponent):
__order__ = 2
def __init__(self, app: WebApp, user_services: UserServices): def __init__(self, app: WebApp, user_services: UserServices):
self.router = APIRouter(prefix="/user") self.router = APIRouter(prefix="/user")
self.router.add_api_route("/register", self.register, methods=["POST"]) self.router.add_api_route("/register", self.register, methods=["POST"])
self.router.add_api_route("/login", self.login, methods=["POST"])
self.user_services = user_services self.user_services = user_services
app.app.include_router(self.router) self.app = app.app
async def initialize(self):
self.router.add_api_route(
"/need_login",
self.need_login,
methods=["GET"],
dependencies=[Depends(self.user_services.repo.AUTH.requires("admin")())],
)
self.app.include_router(self.router)
async def register(self, data: UserRegIn): async def register(self, data: UserRegIn):
if data.username.upper() in SystemUserEnum.__members__: if data.username.upper() in SystemUserEnum.__members__:
@ -22,9 +41,6 @@ class UserRoutes(AsyncInitializingComponent):
user = await self.user_services.get_user(username=data.username) user = await self.user_services.get_user(username=data.username)
if user: if user:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None) return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
user = await self.user_services.get_user(email=data.email)
if user:
return BaseApiOut(status=-1, msg="邮箱已被注册", data=None)
role = "student" role = "student"
if not (data.student_id or data.phone): if not (data.student_id or data.phone):
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None) return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
@ -43,7 +59,6 @@ class UserRoutes(AsyncInitializingComponent):
user = await self.user_services.register_user( user = await self.user_services.register_user(
username=data.username, username=data.username,
password=data.password, password=data.password,
email=data.email,
student_id=data.student_id, student_id=data.student_id,
phone=data.phone, phone=data.phone,
) )
@ -58,3 +73,51 @@ class UserRoutes(AsyncInitializingComponent):
token_info = UserLoginOut.model_validate(user) token_info = UserLoginOut.model_validate(user)
token_info.access_token = await self.user_services.login_user(user) token_info.access_token = await self.user_services.login_user(user)
return BaseApiOut(code=0, msg="注册成功", data=token_info) return BaseApiOut(code=0, msg="注册成功", data=token_info)
async def create_login_history(self, request: "Request"):
# 保存登录记录
ip = request.client.host # 获取真实ip
# 获取代理ip
ips = [
request.headers.get(key, "").strip()
for key in ["x-forwarded-for", "x-real-ip", "x-client-ip", "remote-host"]
]
forwarded_for = ",".join([i for i in set(ips) if i and i != ip])
ua = request.headers.get("user-agent", "")
return await self.user_services.create_login_history(
request.scope.get("user"), ip, ua, forwarded_for
)
async def login(
self, request: Request, response: Response, username: str, password: str
):
auth: "Auth" = request.auth
if request.scope.get("user"):
return BaseApiOut(
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
)
user = await auth.authenticate_user(username=username, password=password)
if not user:
return BaseApiOut(status=-1, msg="用户名或密码错误")
if not user.is_active:
return BaseApiOut(status=-2, msg="用户未激活")
request.scope["user"] = user
try:
await self.create_login_history(request)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error Execute SQL{e}",
) from e
token_info = UserLoginOut.model_validate(request.user)
token_info.access_token = await auth.backend.token_store.write_token(
request.user.dict()
)
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
return BaseApiOut(code=0, data=token_info)
@staticmethod
async def need_login():
return {}

View File

@ -15,3 +15,5 @@ class PhoneMixin(SQLModel):
class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True): class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True):
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
email: Optional[str] = None

View File

@ -2,7 +2,7 @@ from typing import Optional
from fastapi_user_auth.auth import Auth from fastapi_user_auth.auth import Auth
from fastapi_user_auth.auth.backends.redis import RedisTokenStore from fastapi_user_auth.auth.backends.redis import RedisTokenStore
from fastapi_user_auth.auth.models import CasbinRule from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
from persica.factory.component import AsyncInitializingComponent from persica.factory.component import AsyncInitializingComponent
from pydantic import SecretStr from pydantic import SecretStr
from sqlmodel import select from sqlmodel import select
@ -15,28 +15,31 @@ from src.services.users.models import UserModel, RoleModel
class UserRepo(AsyncInitializingComponent): class UserRepo(AsyncInitializingComponent):
__order__ = 1
AUTH: Auth = None AUTH: Auth = None
def __init__(self, app: WebApp, database: Database, redis: RedisDB): def __init__(self, app: WebApp, database: Database, redis: RedisDB):
self.engine = database.engine self.engine = database.engine
self.AUTH = Auth( self.database = database
database.db, self.redis = redis
token_store=RedisTokenStore(redis.client), self.app = app
user_model=UserModel,
)
self.AUTH.backend.attach_middleware(app.app)
self.user_model = UserModel self.user_model = UserModel
self.role_model = RoleModel self.role_model = RoleModel
self.rule_model = CasbinRule self.rule_model = CasbinRule
async def initialize(self): async def initialize(self):
self.AUTH = Auth(
self.database.db,
token_store=RedisTokenStore(self.redis.client),
user_model=UserModel,
)
self.AUTH.backend.attach_middleware(self.app.app)
await self.AUTH.create_role_user("admin") await self.AUTH.create_role_user("admin")
async def register_user( async def register_user(
self, self,
username: str, username: str,
password: SecretStr, password: SecretStr,
email: str,
student_id: Optional[str], student_id: Optional[str],
phone: Optional[str], phone: Optional[str],
): ):
@ -44,7 +47,6 @@ class UserRepo(AsyncInitializingComponent):
values = { values = {
"username": username, "username": username,
"password": password, "password": password,
"email": email,
"student_id": student_id, "student_id": student_id,
"phone": phone, "phone": phone,
} }
@ -58,7 +60,6 @@ class UserRepo(AsyncInitializingComponent):
async def get_user( async def get_user(
self, self,
username: Optional[str] = None, username: Optional[str] = None,
email: Optional[str] = None,
student_id: Optional[str] = None, student_id: Optional[str] = None,
phone: Optional[str] = None, phone: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
@ -66,8 +67,6 @@ class UserRepo(AsyncInitializingComponent):
statement = select(self.user_model) statement = select(self.user_model)
if username: if username:
statement = statement.where(self.user_model.username == username) statement = statement.where(self.user_model.username == username)
if email:
statement = statement.where(self.user_model.email == email)
if student_id: if student_id:
statement = statement.where(self.user_model.student_id == student_id) statement = statement.where(self.user_model.student_id == student_id)
if phone: if phone:
@ -117,3 +116,10 @@ class UserRepo(AsyncInitializingComponent):
await session.commit() await session.commit()
await session.refresh(rule) await session.refresh(rule)
return rule return rule
async def create_login_history(self, login_history: "LoginHistory"):
async with AsyncSession(self.engine) as session:
session.add(login_history)
await session.commit()
await session.refresh(login_history)
return login_history

View File

@ -3,7 +3,7 @@ from typing import Optional
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2 from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
from fastapi_amis_admin.utils.translation import i18n as _ from fastapi_amis_admin.utils.translation import i18n as _
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin
from fastapi_user_auth.utils.sqltypes import SecretStrType from fastapi_user_auth.utils.sqltypes import SecretStrType
from pydantic import BaseModel, SecretStr from pydantic import BaseModel, SecretStr
from sqlmodel import Field from sqlmodel import Field
@ -29,7 +29,7 @@ class UserLoginOut(UserModel):
) )
class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin): class UserRegIn(UsernameMixin, PasswordMixin, StudentIdMixin, PhoneMixin):
"""用户注册""" """用户注册"""
password2: str = Field(title=_("Confirm Password"), max_length=128) password2: str = Field(title=_("Confirm Password"), max_length=128)

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from fastapi_user_auth.auth.models import CasbinRule from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
from persica.factory.component import AsyncInitializingComponent from persica.factory.component import AsyncInitializingComponent
from pydantic import SecretStr from pydantic import SecretStr
@ -10,6 +10,8 @@ from .schemas import UserRoleEnum
class UserServices(AsyncInitializingComponent): class UserServices(AsyncInitializingComponent):
__order__ = 1
def __init__(self, repo: UserRepo): def __init__(self, repo: UserRepo):
self.repo = repo self.repo = repo
self.user_model = UserModel self.user_model = UserModel
@ -19,6 +21,8 @@ class UserServices(AsyncInitializingComponent):
async def initialize(self): async def initialize(self):
for g in UserRoleEnum.__members__.keys(): for g in UserRoleEnum.__members__.keys():
key = g.lower() key = g.lower()
if key == "admin":
continue
if await self.get_role(key=key) is None: if await self.get_role(key=key) is None:
await self.create_role(key, f"{key} role") await self.create_role(key, f"{key} role")
print(f"Create role: {key}") print(f"Create role: {key}")
@ -27,12 +31,14 @@ class UserServices(AsyncInitializingComponent):
self, self,
username: str, username: str,
password: SecretStr, password: SecretStr,
email: str,
student_id: Optional[str], student_id: Optional[str],
phone: Optional[str], phone: Optional[str],
): ):
return await self.repo.register_user( return await self.repo.register_user(
username, password, email, student_id, phone username,
password,
student_id,
phone,
) )
async def login_user(self, user: "UserModel") -> str: async def login_user(self, user: "UserModel") -> str:
@ -41,11 +47,10 @@ class UserServices(AsyncInitializingComponent):
async def get_user( async def get_user(
self, self,
username: Optional[str] = None, username: Optional[str] = None,
email: Optional[str] = None,
student_id: Optional[str] = None, student_id: Optional[str] = None,
phone: Optional[str] = None, phone: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
return await self.repo.get_user(username, email, student_id, phone) return await self.repo.get_user(username, student_id, phone)
async def get_role( async def get_role(
self, rid: Optional[int] = None, key: Optional[str] = None self, rid: Optional[int] = None, key: Optional[str] = None
@ -75,3 +80,16 @@ class UserServices(AsyncInitializingComponent):
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule: async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}") rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
return await self.repo.create_role_rule(rule) return await self.repo.create_role_rule(rule)
async def create_login_history(
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
):
history = LoginHistory(
user_id=user.id,
login_name=user.username,
ip=ip,
user_agent=ua,
login_status="登录成功",
forwarded_for=forwarded_for,
)
return await self.repo.create_login_history(history)