init
This commit is contained in:
parent
ecc2d8929f
commit
51c9e6a6bc
3
.gitignore
vendored
3
.gitignore
vendored
@ -127,3 +127,6 @@ dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
.idea/
|
||||
.env.example
|
||||
data/
|
||||
|
8
apis/check_login.py
Normal file
8
apis/check_login.py
Normal file
@ -0,0 +1,8 @@
|
||||
from defs import app, need_auth_routes
|
||||
|
||||
|
||||
@app.get("/check_login")
|
||||
async def check_login():
|
||||
return {"code": 200, "msg": "登录状态有效"}
|
||||
|
||||
need_auth_routes.append("/check_login")
|
24
apis/check_session.py
Normal file
24
apis/check_session.py
Normal file
@ -0,0 +1,24 @@
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from defs import app, need_auth_routes
|
||||
from models.services.session import SessionAction
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def check_session_middleware(request: Request, call_next):
|
||||
if request.url.path not in need_auth_routes:
|
||||
return await call_next(request)
|
||||
uid = request.cookies.get("uid")
|
||||
session = request.cookies.get("session")
|
||||
try:
|
||||
if not uid or not session:
|
||||
raise ValueError
|
||||
uid = int(uid)
|
||||
session = str(session)
|
||||
auth_success = await SessionAction.check_session(uid, session)
|
||||
if not auth_success:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
return JSONResponse(status_code=401, content={"code": 401, "msg": "Cookie 无效"})
|
||||
return await call_next(request)
|
35
apis/login.py
Normal file
35
apis/login.py
Normal file
@ -0,0 +1,35 @@
|
||||
from errors.user import UserNotFoundError, UserPasswordIncorrectError
|
||||
from fastapi import Response
|
||||
|
||||
from defs import app
|
||||
from models.services.session import SessionAction
|
||||
from models.services.user import UserAction
|
||||
from utils.user import User
|
||||
|
||||
|
||||
async def authenticate_user(username: str, password: str) -> int:
|
||||
user = await UserAction.get_user_by_username(username)
|
||||
if user is None:
|
||||
raise UserNotFoundError
|
||||
if user.password != password:
|
||||
raise UserPasswordIncorrectError
|
||||
return user.uid
|
||||
|
||||
|
||||
async def update_session(uid: int, session: str):
|
||||
await SessionAction.update_session(uid, session)
|
||||
|
||||
|
||||
@app.post("/login")
|
||||
async def login(user: User, response: Response):
|
||||
try:
|
||||
uid = await authenticate_user(user.username, user.password)
|
||||
except UserNotFoundError:
|
||||
return {"code": 403, "msg": "用户不存在"}
|
||||
except UserPasswordIncorrectError:
|
||||
return {"code": 403, "msg": "用户名或密码错误"}
|
||||
session = SessionAction.gen_session()
|
||||
await update_session(uid, session)
|
||||
response.set_cookie(key="uid", value=str(uid))
|
||||
response.set_cookie(key="session", value=session)
|
||||
return {"code": 200, "msg": "登录成功", "data": {"uid": str(uid), "session": session}}
|
25
apis/reg.py
Normal file
25
apis/reg.py
Normal file
@ -0,0 +1,25 @@
|
||||
from errors.user import UserAlreadyExistsError
|
||||
|
||||
from defs import app
|
||||
from models.services.user import UserAction
|
||||
from utils.user import User
|
||||
|
||||
|
||||
async def reg_user(username: str, password: str) -> int:
|
||||
user = await UserAction.get_user_by_username(username)
|
||||
if user:
|
||||
raise UserAlreadyExistsError
|
||||
user = UserAction.gen_new_user(
|
||||
username,
|
||||
password,
|
||||
)
|
||||
await UserAction.add_user(user)
|
||||
|
||||
|
||||
@app.post("/reg")
|
||||
async def reg(user: User):
|
||||
try:
|
||||
await reg_user(user.username, user.password)
|
||||
except UserAlreadyExistsError:
|
||||
return {"code": 409, "msg": "用户已存在"}
|
||||
return {"code": 200, "msg": "注册成功"}
|
10
defs.py
Normal file
10
defs.py
Normal file
@ -0,0 +1,10 @@
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from models.sqlite import Sqlite
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
app = FastAPI()
|
||||
sqlite = Sqlite()
|
||||
need_auth_routes = []
|
0
errors/__init__.py
Normal file
0
errors/__init__.py
Normal file
18
errors/user.py
Normal file
18
errors/user.py
Normal file
@ -0,0 +1,18 @@
|
||||
class UserException(Exception):
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
|
||||
class UserNotFoundError(UserException):
|
||||
def __init__(self, message="User not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserAlreadyExistsError(UserException):
|
||||
def __init__(self, message="User already exists"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserPasswordIncorrectError(UserException):
|
||||
def __init__(self, message="User password incorrect"):
|
||||
super().__init__(message)
|
24
main.py
Normal file
24
main.py
Normal file
@ -0,0 +1,24 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from settings import HOST, PORT
|
||||
from defs import app, sqlite, loop
|
||||
|
||||
# 遍历 apis 文件夹下的所有文件,并且使用 importlib 导入
|
||||
# 从而实现自动导入
|
||||
for filename in os.listdir("apis"):
|
||||
if filename.endswith(".py"):
|
||||
importlib.import_module(f"apis.{filename[:-3]}")
|
||||
|
||||
|
||||
async def main():
|
||||
await sqlite.create_db_and_tables()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(app, host=HOST, port=PORT)
|
||||
)
|
||||
await server.serve()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop.run_until_complete(main())
|
13
models/models/user.py
Normal file
13
models/models/user.py
Normal file
@ -0,0 +1,13 @@
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||
|
||||
uid: int = Field(primary_key=True, default=None)
|
||||
username: str = Field(default="")
|
||||
password: str = Field(default="")
|
||||
is_admin: bool = Field(default=False)
|
||||
register_time: int = Field(default="")
|
||||
last_login_time: int = Field(default="")
|
||||
session: str = Field(default="")
|
45
models/services/session.py
Normal file
45
models/services/session.py
Normal file
@ -0,0 +1,45 @@
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from defs import sqlite
|
||||
from models.models.user import User
|
||||
|
||||
|
||||
class SessionAction:
|
||||
@staticmethod
|
||||
async def check_session(uid: int, session_value: str) -> bool:
|
||||
async with sqlite.session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
statement = select(User).where(User.uid == uid)
|
||||
results = await session.exec(statement)
|
||||
user_: User = user[0] if (user := results.first()) else None
|
||||
return False if user_ is None else user_.session == session_value
|
||||
|
||||
@staticmethod
|
||||
async def update_session(uid: int, session_value: str) -> None:
|
||||
async with sqlite.session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
statement = select(User).where(User.uid == uid)
|
||||
results = await session.exec(statement)
|
||||
user_: User = user[0] if (user := results.first()) else None
|
||||
if user_ is None:
|
||||
return
|
||||
user_.last_login_time = int(time.time())
|
||||
user_.session = session_value
|
||||
await session.commit()
|
||||
await session.refresh(user_)
|
||||
|
||||
@staticmethod
|
||||
def gen_session() -> str:
|
||||
return ''.join(
|
||||
secrets.choice(
|
||||
string.ascii_uppercase + string.ascii_lowercase + string.digits
|
||||
)
|
||||
for _ in range(30)
|
||||
)
|
76
models/services/user.py
Normal file
76
models/services/user.py
Normal file
@ -0,0 +1,76 @@
|
||||
import time
|
||||
from typing import cast, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from defs import sqlite
|
||||
from models.models.user import User
|
||||
|
||||
|
||||
class UserAction:
|
||||
@staticmethod
|
||||
async def add_user(user: User):
|
||||
async with sqlite.session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_username(username: str) -> Optional[User]:
|
||||
async with sqlite.session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
statement = select(User).where(User.username == username)
|
||||
results = await session.exec(statement)
|
||||
return user[0] if (user := results.first()) else None
|
||||
|
||||
@staticmethod
|
||||
async def update_user(old_user: User, new_user: User = None):
|
||||
if new_user:
|
||||
old_user.username = new_user.username
|
||||
old_user.password = new_user.password
|
||||
old_user.is_admin = new_user.is_admin
|
||||
old_user.register_time = new_user.register_time
|
||||
old_user.last_login_time = new_user.last_login_time
|
||||
old_user.session = new_user.session
|
||||
async with sqlite.session() as session:
|
||||
session = cast(AsyncSession, session)
|
||||
session.add(old_user)
|
||||
await session.commit()
|
||||
await session.refresh(old_user)
|
||||
|
||||
@staticmethod
|
||||
async def add_or_update_user(user: User):
|
||||
if old_user := await UserAction.get_user_by_username(user.username):
|
||||
await UserAction.update_user(old_user, user)
|
||||
else:
|
||||
await UserAction.add_user(user)
|
||||
|
||||
@staticmethod
|
||||
async def change_user_password(username: str, password: str) -> bool:
|
||||
user = await UserAction.get_user_by_username(username)
|
||||
if not user:
|
||||
return False
|
||||
user.password = password
|
||||
await UserAction.update_user(user)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def gen_new_user(
|
||||
username: str,
|
||||
password: str,
|
||||
is_admin: bool = False,
|
||||
register_time: int = 0,
|
||||
last_login_time: int = 0,
|
||||
session: str = "",
|
||||
) -> User:
|
||||
if not register_time:
|
||||
register_time = int(time.time())
|
||||
return User(
|
||||
username=username,
|
||||
password=password,
|
||||
is_admin=is_admin,
|
||||
register_time=register_time,
|
||||
last_login_time=last_login_time,
|
||||
session=session,
|
||||
)
|
30
models/sqlite.py
Normal file
30
models/sqlite.py
Normal file
@ -0,0 +1,30 @@
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from models.models.user import User
|
||||
from pathlib import Path
|
||||
|
||||
__all__ = ["User", "Sqlite"]
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
DataPath = Path("data")
|
||||
DataPath.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
class Sqlite:
|
||||
def __init__(self):
|
||||
self.engine = create_async_engine("sqlite+aiosqlite:///data/data.db")
|
||||
self.session = sessionmaker(bind=self.engine, class_=AsyncSession)
|
||||
|
||||
async def create_db_and_tables(self):
|
||||
async with self.engine.begin() as session:
|
||||
await session.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
async def get_session(self):
|
||||
async with self.session() as session:
|
||||
yield session
|
||||
|
||||
def stop(self):
|
||||
self.session.close_all()
|
11
requirements.txt
Normal file
11
requirements.txt
Normal file
@ -0,0 +1,11 @@
|
||||
httpx==0.23.3
|
||||
fastapi==0.94.1
|
||||
python-multipart
|
||||
starlette==0.26.1
|
||||
uvicorn==0.21.0
|
||||
sqlalchemy==1.4.41
|
||||
sqlmodel==0.0.8
|
||||
aiosqlite==0.18.0
|
||||
pydantic~=1.10.6
|
||||
python-dotenv==1.0.0
|
||||
aiofiles==23.1.0
|
7
settings.py
Normal file
7
settings.py
Normal file
@ -0,0 +1,7 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
HOST = os.getenv('HOST')
|
||||
PORT = int(os.getenv('PORT'))
|
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
6
utils/user.py
Normal file
6
utils/user.py
Normal file
@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
username: str
|
||||
password: str
|
Loading…
Reference in New Issue
Block a user