This commit is contained in:
xtaodada 2023-04-14 22:09:47 +08:00
parent ecc2d8929f
commit 51c9e6a6bc
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
17 changed files with 335 additions and 0 deletions

3
.gitignore vendored
View File

@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
.idea/
.env.example
data/

8
apis/check_login.py Normal file
View File

@ -0,0 +1,8 @@
from defs import app, need_auth_routes
@app.get("/check_login")
async def check_login():
return {"code": 200, "msg": "登录状态有效"}
need_auth_routes.append("/check_login")

24
apis/check_session.py Normal file
View File

@ -0,0 +1,24 @@
from fastapi import Request
from fastapi.responses import JSONResponse
from defs import app, need_auth_routes
from models.services.session import SessionAction
@app.middleware("http")
async def check_session_middleware(request: Request, call_next):
if request.url.path not in need_auth_routes:
return await call_next(request)
uid = request.cookies.get("uid")
session = request.cookies.get("session")
try:
if not uid or not session:
raise ValueError
uid = int(uid)
session = str(session)
auth_success = await SessionAction.check_session(uid, session)
if not auth_success:
raise ValueError
except ValueError:
return JSONResponse(status_code=401, content={"code": 401, "msg": "Cookie 无效"})
return await call_next(request)

35
apis/login.py Normal file
View File

@ -0,0 +1,35 @@
from errors.user import UserNotFoundError, UserPasswordIncorrectError
from fastapi import Response
from defs import app
from models.services.session import SessionAction
from models.services.user import UserAction
from utils.user import User
async def authenticate_user(username: str, password: str) -> int:
user = await UserAction.get_user_by_username(username)
if user is None:
raise UserNotFoundError
if user.password != password:
raise UserPasswordIncorrectError
return user.uid
async def update_session(uid: int, session: str):
await SessionAction.update_session(uid, session)
@app.post("/login")
async def login(user: User, response: Response):
try:
uid = await authenticate_user(user.username, user.password)
except UserNotFoundError:
return {"code": 403, "msg": "用户不存在"}
except UserPasswordIncorrectError:
return {"code": 403, "msg": "用户名或密码错误"}
session = SessionAction.gen_session()
await update_session(uid, session)
response.set_cookie(key="uid", value=str(uid))
response.set_cookie(key="session", value=session)
return {"code": 200, "msg": "登录成功", "data": {"uid": str(uid), "session": session}}

25
apis/reg.py Normal file
View File

@ -0,0 +1,25 @@
from errors.user import UserAlreadyExistsError
from defs import app
from models.services.user import UserAction
from utils.user import User
async def reg_user(username: str, password: str) -> int:
user = await UserAction.get_user_by_username(username)
if user:
raise UserAlreadyExistsError
user = UserAction.gen_new_user(
username,
password,
)
await UserAction.add_user(user)
@app.post("/reg")
async def reg(user: User):
try:
await reg_user(user.username, user.password)
except UserAlreadyExistsError:
return {"code": 409, "msg": "用户已存在"}
return {"code": 200, "msg": "注册成功"}

10
defs.py Normal file
View File

@ -0,0 +1,10 @@
import asyncio
from fastapi import FastAPI
from models.sqlite import Sqlite
loop = asyncio.get_event_loop()
app = FastAPI()
sqlite = Sqlite()
need_auth_routes = []

0
errors/__init__.py Normal file
View File

18
errors/user.py Normal file
View File

@ -0,0 +1,18 @@
class UserException(Exception):
def __init__(self, message):
self.message = message
class UserNotFoundError(UserException):
def __init__(self, message="User not found"):
super().__init__(message)
class UserAlreadyExistsError(UserException):
def __init__(self, message="User already exists"):
super().__init__(message)
class UserPasswordIncorrectError(UserException):
def __init__(self, message="User password incorrect"):
super().__init__(message)

24
main.py Normal file
View File

@ -0,0 +1,24 @@
import importlib
import os
import uvicorn
from settings import HOST, PORT
from defs import app, sqlite, loop
# 遍历 apis 文件夹下的所有文件,并且使用 importlib 导入
# 从而实现自动导入
for filename in os.listdir("apis"):
if filename.endswith(".py"):
importlib.import_module(f"apis.{filename[:-3]}")
async def main():
await sqlite.create_db_and_tables()
server = uvicorn.Server(
config=uvicorn.Config(app, host=HOST, port=PORT)
)
await server.serve()
if __name__ == "__main__":
loop.run_until_complete(main())

13
models/models/user.py Normal file
View File

@ -0,0 +1,13 @@
from sqlmodel import SQLModel, Field
class User(SQLModel, table=True):
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
uid: int = Field(primary_key=True, default=None)
username: str = Field(default="")
password: str = Field(default="")
is_admin: bool = Field(default=False)
register_time: int = Field(default="")
last_login_time: int = Field(default="")
session: str = Field(default="")

View File

@ -0,0 +1,45 @@
import secrets
import string
import time
from typing import cast
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from defs import sqlite
from models.models.user import User
class SessionAction:
@staticmethod
async def check_session(uid: int, session_value: str) -> bool:
async with sqlite.session() as session:
session = cast(AsyncSession, session)
statement = select(User).where(User.uid == uid)
results = await session.exec(statement)
user_: User = user[0] if (user := results.first()) else None
return False if user_ is None else user_.session == session_value
@staticmethod
async def update_session(uid: int, session_value: str) -> None:
async with sqlite.session() as session:
session = cast(AsyncSession, session)
statement = select(User).where(User.uid == uid)
results = await session.exec(statement)
user_: User = user[0] if (user := results.first()) else None
if user_ is None:
return
user_.last_login_time = int(time.time())
user_.session = session_value
await session.commit()
await session.refresh(user_)
@staticmethod
def gen_session() -> str:
return ''.join(
secrets.choice(
string.ascii_uppercase + string.ascii_lowercase + string.digits
)
for _ in range(30)
)

76
models/services/user.py Normal file
View File

@ -0,0 +1,76 @@
import time
from typing import cast, Optional
from sqlalchemy import select
from sqlmodel.ext.asyncio.session import AsyncSession
from defs import sqlite
from models.models.user import User
class UserAction:
@staticmethod
async def add_user(user: User):
async with sqlite.session() as session:
session = cast(AsyncSession, session)
session.add(user)
await session.commit()
@staticmethod
async def get_user_by_username(username: str) -> Optional[User]:
async with sqlite.session() as session:
session = cast(AsyncSession, session)
statement = select(User).where(User.username == username)
results = await session.exec(statement)
return user[0] if (user := results.first()) else None
@staticmethod
async def update_user(old_user: User, new_user: User = None):
if new_user:
old_user.username = new_user.username
old_user.password = new_user.password
old_user.is_admin = new_user.is_admin
old_user.register_time = new_user.register_time
old_user.last_login_time = new_user.last_login_time
old_user.session = new_user.session
async with sqlite.session() as session:
session = cast(AsyncSession, session)
session.add(old_user)
await session.commit()
await session.refresh(old_user)
@staticmethod
async def add_or_update_user(user: User):
if old_user := await UserAction.get_user_by_username(user.username):
await UserAction.update_user(old_user, user)
else:
await UserAction.add_user(user)
@staticmethod
async def change_user_password(username: str, password: str) -> bool:
user = await UserAction.get_user_by_username(username)
if not user:
return False
user.password = password
await UserAction.update_user(user)
return True
@staticmethod
def gen_new_user(
username: str,
password: str,
is_admin: bool = False,
register_time: int = 0,
last_login_time: int = 0,
session: str = "",
) -> User:
if not register_time:
register_time = int(time.time())
return User(
username=username,
password=password,
is_admin=is_admin,
register_time=register_time,
last_login_time=last_login_time,
session=session,
)

30
models/sqlite.py Normal file
View File

@ -0,0 +1,30 @@
from sqlmodel import SQLModel
from models.models.user import User
from pathlib import Path
__all__ = ["User", "Sqlite"]
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel.ext.asyncio.session import AsyncSession
DataPath = Path("data")
DataPath.mkdir(exist_ok=True, parents=True)
class Sqlite:
def __init__(self):
self.engine = create_async_engine("sqlite+aiosqlite:///data/data.db")
self.session = sessionmaker(bind=self.engine, class_=AsyncSession)
async def create_db_and_tables(self):
async with self.engine.begin() as session:
await session.run_sync(SQLModel.metadata.create_all)
async def get_session(self):
async with self.session() as session:
yield session
def stop(self):
self.session.close_all()

11
requirements.txt Normal file
View File

@ -0,0 +1,11 @@
httpx==0.23.3
fastapi==0.94.1
python-multipart
starlette==0.26.1
uvicorn==0.21.0
sqlalchemy==1.4.41
sqlmodel==0.0.8
aiosqlite==0.18.0
pydantic~=1.10.6
python-dotenv==1.0.0
aiofiles==23.1.0

7
settings.py Normal file
View File

@ -0,0 +1,7 @@
import os
from dotenv import load_dotenv
load_dotenv()
HOST = os.getenv('HOST')
PORT = int(os.getenv('PORT'))

0
utils/__init__.py Normal file
View File

6
utils/user.py Normal file
View File

@ -0,0 +1,6 @@
from pydantic import BaseModel
class User(BaseModel):
username: str
password: str