feat: login

This commit is contained in:
xtaodada 2024-11-04 19:41:15 +08:00
parent 2b7f89388b
commit a5e58557e3
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
7 changed files with 120 additions and 29 deletions

View File

@ -1,8 +1,8 @@
"""users
Revision ID: 089138f9c051
Revision ID: 3785e9a2a0c0
Revises:
Create Date: 2024-11-04 15:31:52.096235
Create Date: 2024-11-04 19:12:43.374773
"""
@ -12,7 +12,7 @@ import sqlmodel
from fastapi_user_auth.utils.sqltypes import SecretStrType
# revision identifiers, used by Alembic.
revision = "089138f9c051"
revision = "3785e9a2a0c0"
down_revision = None
branch_labels = None
depends_on = None
@ -162,12 +162,12 @@ def upgrade() -> None:
sqlmodel.sql.sqltypes.AutoString(length=255),
nullable=True,
),
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
sa.Column(
"student_id",
sqlmodel.sql.sqltypes.AutoString(length=15),
nullable=True,
),
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(

View File

@ -10,6 +10,8 @@ from src.env import config
class WebApp(AsyncInitializingComponent):
__order__ = 3
def __init__(self):
dependencies = []
self.app = FastAPI(dependencies=dependencies)

View File

@ -1,20 +1,39 @@
from typing import TYPE_CHECKING
from fastapi_amis_admin.crud import BaseApiOut
from persica.factory.component import AsyncInitializingComponent
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends
from starlette import status
from starlette.requests import Request
from starlette.responses import Response
from src.core.web_app import WebApp
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
from src.services.users.services import UserServices
if TYPE_CHECKING:
from fastapi_user_auth.auth import Auth
class UserRoutes(AsyncInitializingComponent):
__order__ = 2
def __init__(self, app: WebApp, user_services: UserServices):
self.router = APIRouter(prefix="/user")
self.router.add_api_route("/register", self.register, methods=["POST"])
self.router.add_api_route("/login", self.login, methods=["POST"])
self.user_services = user_services
app.app.include_router(self.router)
self.app = app.app
async def initialize(self):
self.router.add_api_route(
"/need_login",
self.need_login,
methods=["GET"],
dependencies=[Depends(self.user_services.repo.AUTH.requires("admin")())],
)
self.app.include_router(self.router)
async def register(self, data: UserRegIn):
if data.username.upper() in SystemUserEnum.__members__:
@ -22,9 +41,6 @@ class UserRoutes(AsyncInitializingComponent):
user = await self.user_services.get_user(username=data.username)
if user:
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
user = await self.user_services.get_user(email=data.email)
if user:
return BaseApiOut(status=-1, msg="邮箱已被注册", data=None)
role = "student"
if not (data.student_id or data.phone):
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
@ -43,7 +59,6 @@ class UserRoutes(AsyncInitializingComponent):
user = await self.user_services.register_user(
username=data.username,
password=data.password,
email=data.email,
student_id=data.student_id,
phone=data.phone,
)
@ -58,3 +73,51 @@ class UserRoutes(AsyncInitializingComponent):
token_info = UserLoginOut.model_validate(user)
token_info.access_token = await self.user_services.login_user(user)
return BaseApiOut(code=0, msg="注册成功", data=token_info)
async def create_login_history(self, request: "Request"):
# 保存登录记录
ip = request.client.host # 获取真实ip
# 获取代理ip
ips = [
request.headers.get(key, "").strip()
for key in ["x-forwarded-for", "x-real-ip", "x-client-ip", "remote-host"]
]
forwarded_for = ",".join([i for i in set(ips) if i and i != ip])
ua = request.headers.get("user-agent", "")
return await self.user_services.create_login_history(
request.scope.get("user"), ip, ua, forwarded_for
)
async def login(
self, request: Request, response: Response, username: str, password: str
):
auth: "Auth" = request.auth
if request.scope.get("user"):
return BaseApiOut(
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
)
user = await auth.authenticate_user(username=username, password=password)
if not user:
return BaseApiOut(status=-1, msg="用户名或密码错误")
if not user.is_active:
return BaseApiOut(status=-2, msg="用户未激活")
request.scope["user"] = user
try:
await self.create_login_history(request)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error Execute SQL{e}",
) from e
token_info = UserLoginOut.model_validate(request.user)
token_info.access_token = await auth.backend.token_store.write_token(
request.user.dict()
)
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
return BaseApiOut(code=0, data=token_info)
@staticmethod
async def need_login():
return {}

View File

@ -15,3 +15,5 @@ class PhoneMixin(SQLModel):
class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True):
__table_args__ = {"extend_existing": True}
email: Optional[str] = None

View File

@ -2,7 +2,7 @@ from typing import Optional
from fastapi_user_auth.auth import Auth
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
from fastapi_user_auth.auth.models import CasbinRule
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
from persica.factory.component import AsyncInitializingComponent
from pydantic import SecretStr
from sqlmodel import select
@ -15,28 +15,31 @@ from src.services.users.models import UserModel, RoleModel
class UserRepo(AsyncInitializingComponent):
__order__ = 1
AUTH: Auth = None
def __init__(self, app: WebApp, database: Database, redis: RedisDB):
self.engine = database.engine
self.AUTH = Auth(
database.db,
token_store=RedisTokenStore(redis.client),
user_model=UserModel,
)
self.AUTH.backend.attach_middleware(app.app)
self.database = database
self.redis = redis
self.app = app
self.user_model = UserModel
self.role_model = RoleModel
self.rule_model = CasbinRule
async def initialize(self):
self.AUTH = Auth(
self.database.db,
token_store=RedisTokenStore(self.redis.client),
user_model=UserModel,
)
self.AUTH.backend.attach_middleware(self.app.app)
await self.AUTH.create_role_user("admin")
async def register_user(
self,
username: str,
password: SecretStr,
email: str,
student_id: Optional[str],
phone: Optional[str],
):
@ -44,7 +47,6 @@ class UserRepo(AsyncInitializingComponent):
values = {
"username": username,
"password": password,
"email": email,
"student_id": student_id,
"phone": phone,
}
@ -58,7 +60,6 @@ class UserRepo(AsyncInitializingComponent):
async def get_user(
self,
username: Optional[str] = None,
email: Optional[str] = None,
student_id: Optional[str] = None,
phone: Optional[str] = None,
) -> Optional[UserModel]:
@ -66,8 +67,6 @@ class UserRepo(AsyncInitializingComponent):
statement = select(self.user_model)
if username:
statement = statement.where(self.user_model.username == username)
if email:
statement = statement.where(self.user_model.email == email)
if student_id:
statement = statement.where(self.user_model.student_id == student_id)
if phone:
@ -117,3 +116,10 @@ class UserRepo(AsyncInitializingComponent):
await session.commit()
await session.refresh(rule)
return rule
async def create_login_history(self, login_history: "LoginHistory"):
async with AsyncSession(self.engine) as session:
session.add(login_history)
await session.commit()
await session.refresh(login_history)
return login_history

View File

@ -3,7 +3,7 @@ from typing import Optional
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
from fastapi_amis_admin.utils.translation import i18n as _
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin
from fastapi_user_auth.utils.sqltypes import SecretStrType
from pydantic import BaseModel, SecretStr
from sqlmodel import Field
@ -29,7 +29,7 @@ class UserLoginOut(UserModel):
)
class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin):
class UserRegIn(UsernameMixin, PasswordMixin, StudentIdMixin, PhoneMixin):
"""用户注册"""
password2: str = Field(title=_("Confirm Password"), max_length=128)

View File

@ -1,6 +1,6 @@
from typing import Optional
from fastapi_user_auth.auth.models import CasbinRule
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
from persica.factory.component import AsyncInitializingComponent
from pydantic import SecretStr
@ -10,6 +10,8 @@ from .schemas import UserRoleEnum
class UserServices(AsyncInitializingComponent):
__order__ = 1
def __init__(self, repo: UserRepo):
self.repo = repo
self.user_model = UserModel
@ -19,6 +21,8 @@ class UserServices(AsyncInitializingComponent):
async def initialize(self):
for g in UserRoleEnum.__members__.keys():
key = g.lower()
if key == "admin":
continue
if await self.get_role(key=key) is None:
await self.create_role(key, f"{key} role")
print(f"Create role: {key}")
@ -27,12 +31,14 @@ class UserServices(AsyncInitializingComponent):
self,
username: str,
password: SecretStr,
email: str,
student_id: Optional[str],
phone: Optional[str],
):
return await self.repo.register_user(
username, password, email, student_id, phone
username,
password,
student_id,
phone,
)
async def login_user(self, user: "UserModel") -> str:
@ -41,11 +47,10 @@ class UserServices(AsyncInitializingComponent):
async def get_user(
self,
username: Optional[str] = None,
email: Optional[str] = None,
student_id: Optional[str] = None,
phone: Optional[str] = None,
) -> Optional[UserModel]:
return await self.repo.get_user(username, email, student_id, phone)
return await self.repo.get_user(username, student_id, phone)
async def get_role(
self, rid: Optional[int] = None, key: Optional[str] = None
@ -75,3 +80,16 @@ class UserServices(AsyncInitializingComponent):
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
return await self.repo.create_role_rule(rule)
async def create_login_history(
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
):
history = LoginHistory(
user_id=user.id,
login_name=user.username,
ip=ip,
user_agent=ua,
login_status="登录成功",
forwarded_for=forwarded_for,
)
return await self.repo.create_login_history(history)