feat: login
This commit is contained in:
parent
2b7f89388b
commit
a5e58557e3
@ -1,8 +1,8 @@
|
||||
"""users
|
||||
|
||||
Revision ID: 089138f9c051
|
||||
Revision ID: 3785e9a2a0c0
|
||||
Revises:
|
||||
Create Date: 2024-11-04 15:31:52.096235
|
||||
Create Date: 2024-11-04 19:12:43.374773
|
||||
|
||||
"""
|
||||
|
||||
@ -12,7 +12,7 @@ import sqlmodel
|
||||
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "089138f9c051"
|
||||
revision = "3785e9a2a0c0"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
@ -162,12 +162,12 @@ def upgrade() -> None:
|
||||
sqlmodel.sql.sqltypes.AutoString(length=255),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
|
||||
sa.Column(
|
||||
"student_id",
|
||||
sqlmodel.sql.sqltypes.AutoString(length=15),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
@ -10,6 +10,8 @@ from src.env import config
|
||||
|
||||
|
||||
class WebApp(AsyncInitializingComponent):
|
||||
__order__ = 3
|
||||
|
||||
def __init__(self):
|
||||
dependencies = []
|
||||
self.app = FastAPI(dependencies=dependencies)
|
||||
|
@ -1,20 +1,39 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi_amis_admin.crud import BaseApiOut
|
||||
from persica.factory.component import AsyncInitializingComponent
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from src.core.web_app import WebApp
|
||||
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
|
||||
from src.services.users.services import UserServices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi_user_auth.auth import Auth
|
||||
|
||||
|
||||
class UserRoutes(AsyncInitializingComponent):
|
||||
__order__ = 2
|
||||
|
||||
def __init__(self, app: WebApp, user_services: UserServices):
|
||||
self.router = APIRouter(prefix="/user")
|
||||
self.router.add_api_route("/register", self.register, methods=["POST"])
|
||||
self.router.add_api_route("/login", self.login, methods=["POST"])
|
||||
self.user_services = user_services
|
||||
app.app.include_router(self.router)
|
||||
self.app = app.app
|
||||
|
||||
async def initialize(self):
|
||||
self.router.add_api_route(
|
||||
"/need_login",
|
||||
self.need_login,
|
||||
methods=["GET"],
|
||||
dependencies=[Depends(self.user_services.repo.AUTH.requires("admin")())],
|
||||
)
|
||||
self.app.include_router(self.router)
|
||||
|
||||
async def register(self, data: UserRegIn):
|
||||
if data.username.upper() in SystemUserEnum.__members__:
|
||||
@ -22,9 +41,6 @@ class UserRoutes(AsyncInitializingComponent):
|
||||
user = await self.user_services.get_user(username=data.username)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||
user = await self.user_services.get_user(email=data.email)
|
||||
if user:
|
||||
return BaseApiOut(status=-1, msg="邮箱已被注册", data=None)
|
||||
role = "student"
|
||||
if not (data.student_id or data.phone):
|
||||
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
||||
@ -43,7 +59,6 @@ class UserRoutes(AsyncInitializingComponent):
|
||||
user = await self.user_services.register_user(
|
||||
username=data.username,
|
||||
password=data.password,
|
||||
email=data.email,
|
||||
student_id=data.student_id,
|
||||
phone=data.phone,
|
||||
)
|
||||
@ -58,3 +73,51 @@ class UserRoutes(AsyncInitializingComponent):
|
||||
token_info = UserLoginOut.model_validate(user)
|
||||
token_info.access_token = await self.user_services.login_user(user)
|
||||
return BaseApiOut(code=0, msg="注册成功", data=token_info)
|
||||
|
||||
async def create_login_history(self, request: "Request"):
|
||||
# 保存登录记录
|
||||
ip = request.client.host # 获取真实ip
|
||||
# 获取代理ip
|
||||
ips = [
|
||||
request.headers.get(key, "").strip()
|
||||
for key in ["x-forwarded-for", "x-real-ip", "x-client-ip", "remote-host"]
|
||||
]
|
||||
forwarded_for = ",".join([i for i in set(ips) if i and i != ip])
|
||||
ua = request.headers.get("user-agent", "")
|
||||
return await self.user_services.create_login_history(
|
||||
request.scope.get("user"), ip, ua, forwarded_for
|
||||
)
|
||||
|
||||
async def login(
|
||||
self, request: Request, response: Response, username: str, password: str
|
||||
):
|
||||
auth: "Auth" = request.auth
|
||||
if request.scope.get("user"):
|
||||
return BaseApiOut(
|
||||
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
|
||||
)
|
||||
user = await auth.authenticate_user(username=username, password=password)
|
||||
if not user:
|
||||
return BaseApiOut(status=-1, msg="用户名或密码错误")
|
||||
if not user.is_active:
|
||||
return BaseApiOut(status=-2, msg="用户未激活")
|
||||
request.scope["user"] = user
|
||||
|
||||
try:
|
||||
await self.create_login_history(request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error Execute SQL:{e}",
|
||||
) from e
|
||||
|
||||
token_info = UserLoginOut.model_validate(request.user)
|
||||
token_info.access_token = await auth.backend.token_store.write_token(
|
||||
request.user.dict()
|
||||
)
|
||||
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
|
||||
return BaseApiOut(code=0, data=token_info)
|
||||
|
||||
@staticmethod
|
||||
async def need_login():
|
||||
return {}
|
||||
|
@ -15,3 +15,5 @@ class PhoneMixin(SQLModel):
|
||||
|
||||
class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True):
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
email: Optional[str] = None
|
||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from fastapi_user_auth.auth import Auth
|
||||
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
|
||||
from fastapi_user_auth.auth.models import CasbinRule
|
||||
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
||||
from persica.factory.component import AsyncInitializingComponent
|
||||
from pydantic import SecretStr
|
||||
from sqlmodel import select
|
||||
@ -15,28 +15,31 @@ from src.services.users.models import UserModel, RoleModel
|
||||
|
||||
|
||||
class UserRepo(AsyncInitializingComponent):
|
||||
__order__ = 1
|
||||
AUTH: Auth = None
|
||||
|
||||
def __init__(self, app: WebApp, database: Database, redis: RedisDB):
|
||||
self.engine = database.engine
|
||||
self.AUTH = Auth(
|
||||
database.db,
|
||||
token_store=RedisTokenStore(redis.client),
|
||||
user_model=UserModel,
|
||||
)
|
||||
self.AUTH.backend.attach_middleware(app.app)
|
||||
self.database = database
|
||||
self.redis = redis
|
||||
self.app = app
|
||||
self.user_model = UserModel
|
||||
self.role_model = RoleModel
|
||||
self.rule_model = CasbinRule
|
||||
|
||||
async def initialize(self):
|
||||
self.AUTH = Auth(
|
||||
self.database.db,
|
||||
token_store=RedisTokenStore(self.redis.client),
|
||||
user_model=UserModel,
|
||||
)
|
||||
self.AUTH.backend.attach_middleware(self.app.app)
|
||||
await self.AUTH.create_role_user("admin")
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
username: str,
|
||||
password: SecretStr,
|
||||
email: str,
|
||||
student_id: Optional[str],
|
||||
phone: Optional[str],
|
||||
):
|
||||
@ -44,7 +47,6 @@ class UserRepo(AsyncInitializingComponent):
|
||||
values = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
"email": email,
|
||||
"student_id": student_id,
|
||||
"phone": phone,
|
||||
}
|
||||
@ -58,7 +60,6 @@ class UserRepo(AsyncInitializingComponent):
|
||||
async def get_user(
|
||||
self,
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
student_id: Optional[str] = None,
|
||||
phone: Optional[str] = None,
|
||||
) -> Optional[UserModel]:
|
||||
@ -66,8 +67,6 @@ class UserRepo(AsyncInitializingComponent):
|
||||
statement = select(self.user_model)
|
||||
if username:
|
||||
statement = statement.where(self.user_model.username == username)
|
||||
if email:
|
||||
statement = statement.where(self.user_model.email == email)
|
||||
if student_id:
|
||||
statement = statement.where(self.user_model.student_id == student_id)
|
||||
if phone:
|
||||
@ -117,3 +116,10 @@ class UserRepo(AsyncInitializingComponent):
|
||||
await session.commit()
|
||||
await session.refresh(rule)
|
||||
return rule
|
||||
|
||||
async def create_login_history(self, login_history: "LoginHistory"):
|
||||
async with AsyncSession(self.engine) as session:
|
||||
session.add(login_history)
|
||||
await session.commit()
|
||||
await session.refresh(login_history)
|
||||
return login_history
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||
|
||||
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
||||
from fastapi_amis_admin.utils.translation import i18n as _
|
||||
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin
|
||||
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin
|
||||
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from sqlmodel import Field
|
||||
@ -29,7 +29,7 @@ class UserLoginOut(UserModel):
|
||||
)
|
||||
|
||||
|
||||
class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin):
|
||||
class UserRegIn(UsernameMixin, PasswordMixin, StudentIdMixin, PhoneMixin):
|
||||
"""用户注册"""
|
||||
|
||||
password2: str = Field(title=_("Confirm Password"), max_length=128)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_user_auth.auth.models import CasbinRule
|
||||
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
||||
from persica.factory.component import AsyncInitializingComponent
|
||||
from pydantic import SecretStr
|
||||
|
||||
@ -10,6 +10,8 @@ from .schemas import UserRoleEnum
|
||||
|
||||
|
||||
class UserServices(AsyncInitializingComponent):
|
||||
__order__ = 1
|
||||
|
||||
def __init__(self, repo: UserRepo):
|
||||
self.repo = repo
|
||||
self.user_model = UserModel
|
||||
@ -19,6 +21,8 @@ class UserServices(AsyncInitializingComponent):
|
||||
async def initialize(self):
|
||||
for g in UserRoleEnum.__members__.keys():
|
||||
key = g.lower()
|
||||
if key == "admin":
|
||||
continue
|
||||
if await self.get_role(key=key) is None:
|
||||
await self.create_role(key, f"{key} role")
|
||||
print(f"Create role: {key}")
|
||||
@ -27,12 +31,14 @@ class UserServices(AsyncInitializingComponent):
|
||||
self,
|
||||
username: str,
|
||||
password: SecretStr,
|
||||
email: str,
|
||||
student_id: Optional[str],
|
||||
phone: Optional[str],
|
||||
):
|
||||
return await self.repo.register_user(
|
||||
username, password, email, student_id, phone
|
||||
username,
|
||||
password,
|
||||
student_id,
|
||||
phone,
|
||||
)
|
||||
|
||||
async def login_user(self, user: "UserModel") -> str:
|
||||
@ -41,11 +47,10 @@ class UserServices(AsyncInitializingComponent):
|
||||
async def get_user(
|
||||
self,
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
student_id: Optional[str] = None,
|
||||
phone: Optional[str] = None,
|
||||
) -> Optional[UserModel]:
|
||||
return await self.repo.get_user(username, email, student_id, phone)
|
||||
return await self.repo.get_user(username, student_id, phone)
|
||||
|
||||
async def get_role(
|
||||
self, rid: Optional[int] = None, key: Optional[str] = None
|
||||
@ -75,3 +80,16 @@ class UserServices(AsyncInitializingComponent):
|
||||
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
|
||||
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
|
||||
return await self.repo.create_role_rule(rule)
|
||||
|
||||
async def create_login_history(
|
||||
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
||||
):
|
||||
history = LoginHistory(
|
||||
user_id=user.id,
|
||||
login_name=user.username,
|
||||
ip=ip,
|
||||
user_agent=ua,
|
||||
login_status="登录成功",
|
||||
forwarded_for=forwarded_for,
|
||||
)
|
||||
return await self.repo.create_login_history(history)
|
||||
|
Loading…
Reference in New Issue
Block a user