diff --git a/.gitignore b/.gitignore index b6e4761..9638bdf 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ +.idea/ +.env.example +data/ diff --git a/apis/check_login.py b/apis/check_login.py new file mode 100644 index 0000000..c9848b3 --- /dev/null +++ b/apis/check_login.py @@ -0,0 +1,8 @@ +from defs import app, need_auth_routes + + +@app.get("/check_login") +async def check_login(): + return {"code": 200, "msg": "登录状态有效"} + +need_auth_routes.append("/check_login") diff --git a/apis/check_session.py b/apis/check_session.py new file mode 100644 index 0000000..aa9d627 --- /dev/null +++ b/apis/check_session.py @@ -0,0 +1,24 @@ +from fastapi import Request +from fastapi.responses import JSONResponse + +from defs import app, need_auth_routes +from models.services.session import SessionAction + + +@app.middleware("http") +async def check_session_middleware(request: Request, call_next): + if request.url.path not in need_auth_routes: + return await call_next(request) + uid = request.cookies.get("uid") + session = request.cookies.get("session") + try: + if not uid or not session: + raise ValueError + uid = int(uid) + session = str(session) + auth_success = await SessionAction.check_session(uid, session) + if not auth_success: + raise ValueError + except ValueError: + return JSONResponse(status_code=401, content={"code": 401, "msg": "Cookie 无效"}) + return await call_next(request) diff --git a/apis/login.py b/apis/login.py new file mode 100644 index 0000000..35cbeb4 --- /dev/null +++ b/apis/login.py @@ -0,0 +1,35 @@ +from errors.user import UserNotFoundError, UserPasswordIncorrectError +from fastapi import Response + +from defs import app +from models.services.session import SessionAction +from models.services.user import UserAction +from utils.user import User + + +async def authenticate_user(username: str, password: str) -> int: + user = await UserAction.get_user_by_username(username) + if user is None: + raise UserNotFoundError + if user.password != password: + raise UserPasswordIncorrectError + return user.uid + + +async def update_session(uid: int, session: str): + await SessionAction.update_session(uid, session) + + +@app.post("/login") +async def login(user: User, response: Response): + try: + uid = await authenticate_user(user.username, user.password) + except UserNotFoundError: + return {"code": 403, "msg": "用户不存在"} + except UserPasswordIncorrectError: + return {"code": 403, "msg": "用户名或密码错误"} + session = SessionAction.gen_session() + await update_session(uid, session) + response.set_cookie(key="uid", value=str(uid)) + response.set_cookie(key="session", value=session) + return {"code": 200, "msg": "登录成功", "data": {"uid": str(uid), "session": session}} diff --git a/apis/reg.py b/apis/reg.py new file mode 100644 index 0000000..450e7bf --- /dev/null +++ b/apis/reg.py @@ -0,0 +1,25 @@ +from errors.user import UserAlreadyExistsError + +from defs import app +from models.services.user import UserAction +from utils.user import User + + +async def reg_user(username: str, password: str) -> int: + user = await UserAction.get_user_by_username(username) + if user: + raise UserAlreadyExistsError + user = UserAction.gen_new_user( + username, + password, + ) + await UserAction.add_user(user) + + +@app.post("/reg") +async def reg(user: User): + try: + await reg_user(user.username, user.password) + except UserAlreadyExistsError: + return {"code": 409, "msg": "用户已存在"} + return {"code": 200, "msg": "注册成功"} diff --git a/defs.py b/defs.py new file mode 100644 index 0000000..c5fa96b --- /dev/null +++ b/defs.py @@ -0,0 +1,10 @@ +import asyncio + +from fastapi import FastAPI + +from models.sqlite import Sqlite + +loop = asyncio.get_event_loop() +app = FastAPI() +sqlite = Sqlite() +need_auth_routes = [] diff --git a/errors/__init__.py b/errors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/errors/user.py b/errors/user.py new file mode 100644 index 0000000..ba8535e --- /dev/null +++ b/errors/user.py @@ -0,0 +1,18 @@ +class UserException(Exception): + def __init__(self, message): + self.message = message + + +class UserNotFoundError(UserException): + def __init__(self, message="User not found"): + super().__init__(message) + + +class UserAlreadyExistsError(UserException): + def __init__(self, message="User already exists"): + super().__init__(message) + + +class UserPasswordIncorrectError(UserException): + def __init__(self, message="User password incorrect"): + super().__init__(message) diff --git a/main.py b/main.py new file mode 100644 index 0000000..165b342 --- /dev/null +++ b/main.py @@ -0,0 +1,24 @@ +import importlib +import os + +import uvicorn +from settings import HOST, PORT +from defs import app, sqlite, loop + +# 遍历 apis 文件夹下的所有文件,并且使用 importlib 导入 +# 从而实现自动导入 +for filename in os.listdir("apis"): + if filename.endswith(".py"): + importlib.import_module(f"apis.{filename[:-3]}") + + +async def main(): + await sqlite.create_db_and_tables() + server = uvicorn.Server( + config=uvicorn.Config(app, host=HOST, port=PORT) + ) + await server.serve() + + +if __name__ == "__main__": + loop.run_until_complete(main()) diff --git a/models/models/user.py b/models/models/user.py new file mode 100644 index 0000000..3d86bb7 --- /dev/null +++ b/models/models/user.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field + + +class User(SQLModel, table=True): + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + + uid: int = Field(primary_key=True, default=None) + username: str = Field(default="") + password: str = Field(default="") + is_admin: bool = Field(default=False) + register_time: int = Field(default="") + last_login_time: int = Field(default="") + session: str = Field(default="") diff --git a/models/services/session.py b/models/services/session.py new file mode 100644 index 0000000..c07f917 --- /dev/null +++ b/models/services/session.py @@ -0,0 +1,45 @@ +import secrets +import string +import time + +from typing import cast + +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from defs import sqlite +from models.models.user import User + + +class SessionAction: + @staticmethod + async def check_session(uid: int, session_value: str) -> bool: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + statement = select(User).where(User.uid == uid) + results = await session.exec(statement) + user_: User = user[0] if (user := results.first()) else None + return False if user_ is None else user_.session == session_value + + @staticmethod + async def update_session(uid: int, session_value: str) -> None: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + statement = select(User).where(User.uid == uid) + results = await session.exec(statement) + user_: User = user[0] if (user := results.first()) else None + if user_ is None: + return + user_.last_login_time = int(time.time()) + user_.session = session_value + await session.commit() + await session.refresh(user_) + + @staticmethod + def gen_session() -> str: + return ''.join( + secrets.choice( + string.ascii_uppercase + string.ascii_lowercase + string.digits + ) + for _ in range(30) + ) diff --git a/models/services/user.py b/models/services/user.py new file mode 100644 index 0000000..cbeb1aa --- /dev/null +++ b/models/services/user.py @@ -0,0 +1,76 @@ +import time +from typing import cast, Optional + +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from defs import sqlite +from models.models.user import User + + +class UserAction: + @staticmethod + async def add_user(user: User): + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(user) + await session.commit() + + @staticmethod + async def get_user_by_username(username: str) -> Optional[User]: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + statement = select(User).where(User.username == username) + results = await session.exec(statement) + return user[0] if (user := results.first()) else None + + @staticmethod + async def update_user(old_user: User, new_user: User = None): + if new_user: + old_user.username = new_user.username + old_user.password = new_user.password + old_user.is_admin = new_user.is_admin + old_user.register_time = new_user.register_time + old_user.last_login_time = new_user.last_login_time + old_user.session = new_user.session + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(old_user) + await session.commit() + await session.refresh(old_user) + + @staticmethod + async def add_or_update_user(user: User): + if old_user := await UserAction.get_user_by_username(user.username): + await UserAction.update_user(old_user, user) + else: + await UserAction.add_user(user) + + @staticmethod + async def change_user_password(username: str, password: str) -> bool: + user = await UserAction.get_user_by_username(username) + if not user: + return False + user.password = password + await UserAction.update_user(user) + return True + + @staticmethod + def gen_new_user( + username: str, + password: str, + is_admin: bool = False, + register_time: int = 0, + last_login_time: int = 0, + session: str = "", + ) -> User: + if not register_time: + register_time = int(time.time()) + return User( + username=username, + password=password, + is_admin=is_admin, + register_time=register_time, + last_login_time=last_login_time, + session=session, + ) diff --git a/models/sqlite.py b/models/sqlite.py new file mode 100644 index 0000000..6b05f9f --- /dev/null +++ b/models/sqlite.py @@ -0,0 +1,30 @@ +from sqlmodel import SQLModel + +from models.models.user import User +from pathlib import Path + +__all__ = ["User", "Sqlite"] + +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel.ext.asyncio.session import AsyncSession + +DataPath = Path("data") +DataPath.mkdir(exist_ok=True, parents=True) + + +class Sqlite: + def __init__(self): + self.engine = create_async_engine("sqlite+aiosqlite:///data/data.db") + self.session = sessionmaker(bind=self.engine, class_=AsyncSession) + + async def create_db_and_tables(self): + async with self.engine.begin() as session: + await session.run_sync(SQLModel.metadata.create_all) + + async def get_session(self): + async with self.session() as session: + yield session + + def stop(self): + self.session.close_all() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e03dc17 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +httpx==0.23.3 +fastapi==0.94.1 +python-multipart +starlette==0.26.1 +uvicorn==0.21.0 +sqlalchemy==1.4.41 +sqlmodel==0.0.8 +aiosqlite==0.18.0 +pydantic~=1.10.6 +python-dotenv==1.0.0 +aiofiles==23.1.0 diff --git a/settings.py b/settings.py new file mode 100644 index 0000000..037d82b --- /dev/null +++ b/settings.py @@ -0,0 +1,7 @@ +import os +from dotenv import load_dotenv + +load_dotenv() + +HOST = os.getenv('HOST') +PORT = int(os.getenv('PORT')) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/user.py b/utils/user.py new file mode 100644 index 0000000..35a2129 --- /dev/null +++ b/utils/user.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class User(BaseModel): + username: str + password: str