feat: login
This commit is contained in:
parent
2b7f89388b
commit
a5e58557e3
@ -1,8 +1,8 @@
|
|||||||
"""users
|
"""users
|
||||||
|
|
||||||
Revision ID: 089138f9c051
|
Revision ID: 3785e9a2a0c0
|
||||||
Revises:
|
Revises:
|
||||||
Create Date: 2024-11-04 15:31:52.096235
|
Create Date: 2024-11-04 19:12:43.374773
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ import sqlmodel
|
|||||||
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = "089138f9c051"
|
revision = "3785e9a2a0c0"
|
||||||
down_revision = None
|
down_revision = None
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
@ -162,12 +162,12 @@ def upgrade() -> None:
|
|||||||
sqlmodel.sql.sqltypes.AutoString(length=255),
|
sqlmodel.sql.sqltypes.AutoString(length=255),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
|
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"student_id",
|
"student_id",
|
||||||
sqlmodel.sql.sqltypes.AutoString(length=15),
|
sqlmodel.sql.sqltypes.AutoString(length=15),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
sa.Column("phone", sqlmodel.sql.sqltypes.AutoString(length=15), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
op.create_index(
|
op.create_index(
|
@ -10,6 +10,8 @@ from src.env import config
|
|||||||
|
|
||||||
|
|
||||||
class WebApp(AsyncInitializingComponent):
|
class WebApp(AsyncInitializingComponent):
|
||||||
|
__order__ = 3
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
dependencies = []
|
dependencies = []
|
||||||
self.app = FastAPI(dependencies=dependencies)
|
self.app = FastAPI(dependencies=dependencies)
|
||||||
|
@ -1,20 +1,39 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi_amis_admin.crud import BaseApiOut
|
from fastapi_amis_admin.crud import BaseApiOut
|
||||||
from persica.factory.component import AsyncInitializingComponent
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Depends
|
||||||
from starlette import status
|
from starlette import status
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
from src.core.web_app import WebApp
|
from src.core.web_app import WebApp
|
||||||
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
|
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
|
||||||
from src.services.users.services import UserServices
|
from src.services.users.services import UserServices
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastapi_user_auth.auth import Auth
|
||||||
|
|
||||||
|
|
||||||
class UserRoutes(AsyncInitializingComponent):
|
class UserRoutes(AsyncInitializingComponent):
|
||||||
|
__order__ = 2
|
||||||
|
|
||||||
def __init__(self, app: WebApp, user_services: UserServices):
|
def __init__(self, app: WebApp, user_services: UserServices):
|
||||||
self.router = APIRouter(prefix="/user")
|
self.router = APIRouter(prefix="/user")
|
||||||
self.router.add_api_route("/register", self.register, methods=["POST"])
|
self.router.add_api_route("/register", self.register, methods=["POST"])
|
||||||
|
self.router.add_api_route("/login", self.login, methods=["POST"])
|
||||||
self.user_services = user_services
|
self.user_services = user_services
|
||||||
app.app.include_router(self.router)
|
self.app = app.app
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
self.router.add_api_route(
|
||||||
|
"/need_login",
|
||||||
|
self.need_login,
|
||||||
|
methods=["GET"],
|
||||||
|
dependencies=[Depends(self.user_services.repo.AUTH.requires("admin")())],
|
||||||
|
)
|
||||||
|
self.app.include_router(self.router)
|
||||||
|
|
||||||
async def register(self, data: UserRegIn):
|
async def register(self, data: UserRegIn):
|
||||||
if data.username.upper() in SystemUserEnum.__members__:
|
if data.username.upper() in SystemUserEnum.__members__:
|
||||||
@ -22,9 +41,6 @@ class UserRoutes(AsyncInitializingComponent):
|
|||||||
user = await self.user_services.get_user(username=data.username)
|
user = await self.user_services.get_user(username=data.username)
|
||||||
if user:
|
if user:
|
||||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||||
user = await self.user_services.get_user(email=data.email)
|
|
||||||
if user:
|
|
||||||
return BaseApiOut(status=-1, msg="邮箱已被注册", data=None)
|
|
||||||
role = "student"
|
role = "student"
|
||||||
if not (data.student_id or data.phone):
|
if not (data.student_id or data.phone):
|
||||||
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
||||||
@ -43,7 +59,6 @@ class UserRoutes(AsyncInitializingComponent):
|
|||||||
user = await self.user_services.register_user(
|
user = await self.user_services.register_user(
|
||||||
username=data.username,
|
username=data.username,
|
||||||
password=data.password,
|
password=data.password,
|
||||||
email=data.email,
|
|
||||||
student_id=data.student_id,
|
student_id=data.student_id,
|
||||||
phone=data.phone,
|
phone=data.phone,
|
||||||
)
|
)
|
||||||
@ -58,3 +73,51 @@ class UserRoutes(AsyncInitializingComponent):
|
|||||||
token_info = UserLoginOut.model_validate(user)
|
token_info = UserLoginOut.model_validate(user)
|
||||||
token_info.access_token = await self.user_services.login_user(user)
|
token_info.access_token = await self.user_services.login_user(user)
|
||||||
return BaseApiOut(code=0, msg="注册成功", data=token_info)
|
return BaseApiOut(code=0, msg="注册成功", data=token_info)
|
||||||
|
|
||||||
|
async def create_login_history(self, request: "Request"):
|
||||||
|
# 保存登录记录
|
||||||
|
ip = request.client.host # 获取真实ip
|
||||||
|
# 获取代理ip
|
||||||
|
ips = [
|
||||||
|
request.headers.get(key, "").strip()
|
||||||
|
for key in ["x-forwarded-for", "x-real-ip", "x-client-ip", "remote-host"]
|
||||||
|
]
|
||||||
|
forwarded_for = ",".join([i for i in set(ips) if i and i != ip])
|
||||||
|
ua = request.headers.get("user-agent", "")
|
||||||
|
return await self.user_services.create_login_history(
|
||||||
|
request.scope.get("user"), ip, ua, forwarded_for
|
||||||
|
)
|
||||||
|
|
||||||
|
async def login(
|
||||||
|
self, request: Request, response: Response, username: str, password: str
|
||||||
|
):
|
||||||
|
auth: "Auth" = request.auth
|
||||||
|
if request.scope.get("user"):
|
||||||
|
return BaseApiOut(
|
||||||
|
code=1, msg="用户已登录", data=UserLoginOut.model_validate(request.user)
|
||||||
|
)
|
||||||
|
user = await auth.authenticate_user(username=username, password=password)
|
||||||
|
if not user:
|
||||||
|
return BaseApiOut(status=-1, msg="用户名或密码错误")
|
||||||
|
if not user.is_active:
|
||||||
|
return BaseApiOut(status=-2, msg="用户未激活")
|
||||||
|
request.scope["user"] = user
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.create_login_history(request)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Error Execute SQL:{e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
token_info = UserLoginOut.model_validate(request.user)
|
||||||
|
token_info.access_token = await auth.backend.token_store.write_token(
|
||||||
|
request.user.dict()
|
||||||
|
)
|
||||||
|
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
|
||||||
|
return BaseApiOut(code=0, data=token_info)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def need_login():
|
||||||
|
return {}
|
||||||
|
@ -15,3 +15,5 @@ class PhoneMixin(SQLModel):
|
|||||||
|
|
||||||
class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True):
|
class UserModel(BaseUser, StudentIdMixin, PhoneMixin, table=True):
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True}
|
||||||
|
|
||||||
|
email: Optional[str] = None
|
||||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi_user_auth.auth import Auth
|
from fastapi_user_auth.auth import Auth
|
||||||
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
|
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
|
||||||
from fastapi_user_auth.auth.models import CasbinRule
|
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
||||||
from persica.factory.component import AsyncInitializingComponent
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
@ -15,28 +15,31 @@ from src.services.users.models import UserModel, RoleModel
|
|||||||
|
|
||||||
|
|
||||||
class UserRepo(AsyncInitializingComponent):
|
class UserRepo(AsyncInitializingComponent):
|
||||||
|
__order__ = 1
|
||||||
AUTH: Auth = None
|
AUTH: Auth = None
|
||||||
|
|
||||||
def __init__(self, app: WebApp, database: Database, redis: RedisDB):
|
def __init__(self, app: WebApp, database: Database, redis: RedisDB):
|
||||||
self.engine = database.engine
|
self.engine = database.engine
|
||||||
self.AUTH = Auth(
|
self.database = database
|
||||||
database.db,
|
self.redis = redis
|
||||||
token_store=RedisTokenStore(redis.client),
|
self.app = app
|
||||||
user_model=UserModel,
|
|
||||||
)
|
|
||||||
self.AUTH.backend.attach_middleware(app.app)
|
|
||||||
self.user_model = UserModel
|
self.user_model = UserModel
|
||||||
self.role_model = RoleModel
|
self.role_model = RoleModel
|
||||||
self.rule_model = CasbinRule
|
self.rule_model = CasbinRule
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
self.AUTH = Auth(
|
||||||
|
self.database.db,
|
||||||
|
token_store=RedisTokenStore(self.redis.client),
|
||||||
|
user_model=UserModel,
|
||||||
|
)
|
||||||
|
self.AUTH.backend.attach_middleware(self.app.app)
|
||||||
await self.AUTH.create_role_user("admin")
|
await self.AUTH.create_role_user("admin")
|
||||||
|
|
||||||
async def register_user(
|
async def register_user(
|
||||||
self,
|
self,
|
||||||
username: str,
|
username: str,
|
||||||
password: SecretStr,
|
password: SecretStr,
|
||||||
email: str,
|
|
||||||
student_id: Optional[str],
|
student_id: Optional[str],
|
||||||
phone: Optional[str],
|
phone: Optional[str],
|
||||||
):
|
):
|
||||||
@ -44,7 +47,6 @@ class UserRepo(AsyncInitializingComponent):
|
|||||||
values = {
|
values = {
|
||||||
"username": username,
|
"username": username,
|
||||||
"password": password,
|
"password": password,
|
||||||
"email": email,
|
|
||||||
"student_id": student_id,
|
"student_id": student_id,
|
||||||
"phone": phone,
|
"phone": phone,
|
||||||
}
|
}
|
||||||
@ -58,7 +60,6 @@ class UserRepo(AsyncInitializingComponent):
|
|||||||
async def get_user(
|
async def get_user(
|
||||||
self,
|
self,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
email: Optional[str] = None,
|
|
||||||
student_id: Optional[str] = None,
|
student_id: Optional[str] = None,
|
||||||
phone: Optional[str] = None,
|
phone: Optional[str] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
@ -66,8 +67,6 @@ class UserRepo(AsyncInitializingComponent):
|
|||||||
statement = select(self.user_model)
|
statement = select(self.user_model)
|
||||||
if username:
|
if username:
|
||||||
statement = statement.where(self.user_model.username == username)
|
statement = statement.where(self.user_model.username == username)
|
||||||
if email:
|
|
||||||
statement = statement.where(self.user_model.email == email)
|
|
||||||
if student_id:
|
if student_id:
|
||||||
statement = statement.where(self.user_model.student_id == student_id)
|
statement = statement.where(self.user_model.student_id == student_id)
|
||||||
if phone:
|
if phone:
|
||||||
@ -117,3 +116,10 @@ class UserRepo(AsyncInitializingComponent):
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(rule)
|
await session.refresh(rule)
|
||||||
return rule
|
return rule
|
||||||
|
|
||||||
|
async def create_login_history(self, login_history: "LoginHistory"):
|
||||||
|
async with AsyncSession(self.engine) as session:
|
||||||
|
session.add(login_history)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(login_history)
|
||||||
|
return login_history
|
||||||
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
from fastapi_amis_admin.utils.pydantic import PYDANTIC_V2
|
||||||
from fastapi_amis_admin.utils.translation import i18n as _
|
from fastapi_amis_admin.utils.translation import i18n as _
|
||||||
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin, EmailMixin
|
from fastapi_user_auth.mixins.models import UsernameMixin, PasswordMixin
|
||||||
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
from fastapi_user_auth.utils.sqltypes import SecretStrType
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
from sqlmodel import Field
|
from sqlmodel import Field
|
||||||
@ -29,7 +29,7 @@ class UserLoginOut(UserModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserRegIn(UsernameMixin, PasswordMixin, EmailMixin, StudentIdMixin, PhoneMixin):
|
class UserRegIn(UsernameMixin, PasswordMixin, StudentIdMixin, PhoneMixin):
|
||||||
"""用户注册"""
|
"""用户注册"""
|
||||||
|
|
||||||
password2: str = Field(title=_("Confirm Password"), max_length=128)
|
password2: str = Field(title=_("Confirm Password"), max_length=128)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi_user_auth.auth.models import CasbinRule
|
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
||||||
from persica.factory.component import AsyncInitializingComponent
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
@ -10,6 +10,8 @@ from .schemas import UserRoleEnum
|
|||||||
|
|
||||||
|
|
||||||
class UserServices(AsyncInitializingComponent):
|
class UserServices(AsyncInitializingComponent):
|
||||||
|
__order__ = 1
|
||||||
|
|
||||||
def __init__(self, repo: UserRepo):
|
def __init__(self, repo: UserRepo):
|
||||||
self.repo = repo
|
self.repo = repo
|
||||||
self.user_model = UserModel
|
self.user_model = UserModel
|
||||||
@ -19,6 +21,8 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
for g in UserRoleEnum.__members__.keys():
|
for g in UserRoleEnum.__members__.keys():
|
||||||
key = g.lower()
|
key = g.lower()
|
||||||
|
if key == "admin":
|
||||||
|
continue
|
||||||
if await self.get_role(key=key) is None:
|
if await self.get_role(key=key) is None:
|
||||||
await self.create_role(key, f"{key} role")
|
await self.create_role(key, f"{key} role")
|
||||||
print(f"Create role: {key}")
|
print(f"Create role: {key}")
|
||||||
@ -27,12 +31,14 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
self,
|
self,
|
||||||
username: str,
|
username: str,
|
||||||
password: SecretStr,
|
password: SecretStr,
|
||||||
email: str,
|
|
||||||
student_id: Optional[str],
|
student_id: Optional[str],
|
||||||
phone: Optional[str],
|
phone: Optional[str],
|
||||||
):
|
):
|
||||||
return await self.repo.register_user(
|
return await self.repo.register_user(
|
||||||
username, password, email, student_id, phone
|
username,
|
||||||
|
password,
|
||||||
|
student_id,
|
||||||
|
phone,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def login_user(self, user: "UserModel") -> str:
|
async def login_user(self, user: "UserModel") -> str:
|
||||||
@ -41,11 +47,10 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
async def get_user(
|
async def get_user(
|
||||||
self,
|
self,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
email: Optional[str] = None,
|
|
||||||
student_id: Optional[str] = None,
|
student_id: Optional[str] = None,
|
||||||
phone: Optional[str] = None,
|
phone: Optional[str] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
return await self.repo.get_user(username, email, student_id, phone)
|
return await self.repo.get_user(username, student_id, phone)
|
||||||
|
|
||||||
async def get_role(
|
async def get_role(
|
||||||
self, rid: Optional[int] = None, key: Optional[str] = None
|
self, rid: Optional[int] = None, key: Optional[str] = None
|
||||||
@ -75,3 +80,16 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
|
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
|
||||||
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
|
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
|
||||||
return await self.repo.create_role_rule(rule)
|
return await self.repo.create_role_rule(rule)
|
||||||
|
|
||||||
|
async def create_login_history(
|
||||||
|
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
||||||
|
):
|
||||||
|
history = LoginHistory(
|
||||||
|
user_id=user.id,
|
||||||
|
login_name=user.username,
|
||||||
|
ip=ip,
|
||||||
|
user_agent=ua,
|
||||||
|
login_status="登录成功",
|
||||||
|
forwarded_for=forwarded_for,
|
||||||
|
)
|
||||||
|
return await self.repo.create_login_history(history)
|
||||||
|
Loading…
Reference in New Issue
Block a user