From eb2d7b7bdda58ba8073e5c3634b78d83cd828a7e Mon Sep 17 00:00:00 2001 From: xtaodada Date: Sat, 15 Apr 2023 21:58:25 +0800 Subject: [PATCH] create some funcs --- apis/middleware/check_admin.py | 22 +++++++ apis/{ => middleware}/check_session.py | 7 ++- apis/post/create_post.py | 47 +++++++++++++++ apis/post/edit_post.py | 46 +++++++++++++++ apis/post/get_posts.py | 49 +++++++++++++++ apis/topic/create_topic.py | 24 ++++++++ apis/topic/get_topics.py | 15 +++++ apis/{ => user}/change_password.py | 0 apis/{ => user}/check_login.py | 0 apis/user/get_me.py | 19 ++++++ apis/{ => user}/login.py | 0 apis/{ => user}/reg.py | 0 defs.py | 2 + errors/post.py | 28 +++++++++ main.py | 8 ++- models/models/post.py | 28 +++++++++ models/models/topic.py | 20 +++++++ models/models/user.py | 11 ++++ models/services/post.py | 82 ++++++++++++++++++++++++++ models/services/topic.py | 64 ++++++++++++++++++++ models/services/user.py | 13 ++++ 21 files changed, 480 insertions(+), 5 deletions(-) create mode 100644 apis/middleware/check_admin.py rename apis/{ => middleware}/check_session.py (71%) create mode 100644 apis/post/create_post.py create mode 100644 apis/post/edit_post.py create mode 100644 apis/post/get_posts.py create mode 100644 apis/topic/create_topic.py create mode 100644 apis/topic/get_topics.py rename apis/{ => user}/change_password.py (100%) rename apis/{ => user}/check_login.py (100%) create mode 100644 apis/user/get_me.py rename apis/{ => user}/login.py (100%) rename apis/{ => user}/reg.py (100%) create mode 100644 errors/post.py create mode 100644 models/models/post.py create mode 100644 models/models/topic.py create mode 100644 models/services/post.py create mode 100644 models/services/topic.py diff --git a/apis/middleware/check_admin.py b/apis/middleware/check_admin.py new file mode 100644 index 0000000..90a67f9 --- /dev/null +++ b/apis/middleware/check_admin.py @@ -0,0 +1,22 @@ +from fastapi import Request +from fastapi.responses import JSONResponse + +from defs import app, need_admin_routes +from models.services.user import UserAction + + +@app.middleware("http") +async def check_admin_middleware(request: Request, call_next): + if request.url.path not in need_admin_routes: + return await call_next(request) + uid = request.cookies.get("uid") + session = request.cookies.get("session") + try: + if not uid or not session: + raise ValueError + uid = int(uid) + if not await UserAction.check_admin(uid): + raise ValueError + except ValueError: + return JSONResponse(status_code=403, content={"code": 403, "msg": "此操作需要管理员权限"}) + return await call_next(request) diff --git a/apis/check_session.py b/apis/middleware/check_session.py similarity index 71% rename from apis/check_session.py rename to apis/middleware/check_session.py index aa9d627..35550d7 100644 --- a/apis/check_session.py +++ b/apis/middleware/check_session.py @@ -1,7 +1,7 @@ from fastapi import Request from fastapi.responses import JSONResponse -from defs import app, need_auth_routes +from defs import app, need_auth_routes, need_auth_uid_only_routes from models.services.session import SessionAction @@ -20,5 +20,8 @@ async def check_session_middleware(request: Request, call_next): if not auth_success: raise ValueError except ValueError: - return JSONResponse(status_code=401, content={"code": 401, "msg": "Cookie 无效"}) + if request.url.path in need_auth_uid_only_routes: + request.cookies["uid"] = "" + else: + return JSONResponse(status_code=401, content={"code": 401, "msg": "Cookie 无效"}) return await call_next(request) diff --git a/apis/post/create_post.py b/apis/post/create_post.py new file mode 100644 index 0000000..6280d09 --- /dev/null +++ b/apis/post/create_post.py @@ -0,0 +1,47 @@ +from pydantic import BaseModel + +from defs import app, need_auth_routes +from errors.post import * +from fastapi import Request +from models.services.post import PostAction +from models.services.topic import TopicAction +from models.services.user import UserAction + + +class CreatePost(BaseModel): + tid: int + title: str + content: str + + +async def create_post_func(model: CreatePost, uid: int): + topic = await TopicAction.get_topic_by_tid(model.tid) + if topic is None: + raise PostTopicNotValidException() + if len(model.title) > 100: + raise PostTitleTooLongException() + if len(model.content) > 5000: + raise PostContentTooLongException() + if topic.need_admin: + if not await UserAction.check_admin(uid): + raise PostTopicNeedAdminException() + post = PostAction.gen_new_post( + model.tid, + uid, + model.title, + model.content, + ) + await PostAction.add_post(post) + + +@app.post("/create_post") +async def create_post(post: CreatePost, request: Request): + uid = int(request.cookies.get("uid")) + try: + await create_post_func(post, uid) + except PostException as e: + return {"code": 400, "msg": e.message} + return {"code": 200, "msg": "创建成功"} + + +need_auth_routes.append("/create_post") diff --git a/apis/post/edit_post.py b/apis/post/edit_post.py new file mode 100644 index 0000000..ffd5d91 --- /dev/null +++ b/apis/post/edit_post.py @@ -0,0 +1,46 @@ +from pydantic import BaseModel + +from defs import app, need_auth_routes +from errors.post import * +from fastapi import Request +from models.services.post import PostAction +from models.services.user import UserAction + +import time + + +class EditPost(BaseModel): + pid: int + title: str + content: str + + +async def edit_post_func(model: EditPost, uid: int): + post = await PostAction.get_post_by_pid(model.pid) + if post is None: + raise PostNotExistException() + admin = await UserAction.check_admin(uid) + if not admin: + if post.uid != uid: + raise PostTopicNeedAdminException() + if len(model.title) > 100: + raise PostTitleTooLongException() + if len(model.content) > 5000: + raise PostContentTooLongException() + post.title = model.title + post.content = model.content + post.update_time = int(time.time()) + await PostAction.update_post(post) + + +@app.post("/edit_post") +async def edit_post(post: EditPost, request: Request): + uid = int(request.cookies.get("uid")) + try: + await edit_post_func(post, uid) + except PostException as e: + return {"code": 400, "msg": e.message} + return {"code": 200, "msg": "修改成功"} + + +need_auth_routes.append("/edit_post") diff --git a/apis/post/get_posts.py b/apis/post/get_posts.py new file mode 100644 index 0000000..c7c3bd9 --- /dev/null +++ b/apis/post/get_posts.py @@ -0,0 +1,49 @@ +from typing import List, Dict + +from pydantic import BaseModel + +from defs import app +from errors.post import * +from fastapi import Request +from models.services.post import PostAction +from models.services.topic import TopicAction +from models.services.user import UserAction + + +class GetPost(BaseModel): + tid: int + + +async def get_post_func(model: GetPost = None, uid: int = None) -> List[Dict]: + tid = model.tid if model else None + admin = await UserAction.check_admin(uid) if uid else False + if tid: + topic = await TopicAction.get_topic_by_tid(tid) + if not topic: + raise PostTopicNotValidException() + if topic.need_admin: + if not admin: + raise PostTopicNeedAdminException() + posts = await PostAction.get_posts_by_tid(tid, admin) + return [post.dict_post() for post in posts] + + +@app.get("/get_posts") +async def get_posts_get(request: Request): + uid = request.cookies.get("uid") + if uid is not None: + uid = int(uid) + data = await get_post_func(uid=uid) + return {"code": 200, "msg": "获取成功", "data": data} + + +@app.post("/get_posts") +async def get_posts_post(model: GetPost, request: Request): + uid = request.cookies.get("uid") + if uid is not None: + uid = int(uid) + try: + data = await get_post_func(model, uid) + except PostException as e: + return {"code": 403, "msg": e.message} + return {"code": 200, "msg": "获取成功", "data": data} diff --git a/apis/topic/create_topic.py b/apis/topic/create_topic.py new file mode 100644 index 0000000..8eac2e3 --- /dev/null +++ b/apis/topic/create_topic.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel + +from defs import app, need_auth_routes, need_admin_routes +from models.services.topic import TopicAction + + +class CreateTopic(BaseModel): + title: str + need_admin: bool + + +async def create_topic_func(title: str, need_admin: bool): + topic = TopicAction.gen_new_topic(title, need_admin=need_admin) + await TopicAction.add_topic(topic) + + +@app.post("/create_topic") +async def create_topic(topic: CreateTopic): + await create_topic_func(topic.title, topic.need_admin) + return {"code": 200, "msg": "创建成功"} + + +need_auth_routes.append("/create_topic") +need_admin_routes.append("/create_topic") diff --git a/apis/topic/get_topics.py b/apis/topic/get_topics.py new file mode 100644 index 0000000..10538c3 --- /dev/null +++ b/apis/topic/get_topics.py @@ -0,0 +1,15 @@ +from typing import List, Dict, Any + +from defs import app +from models.services.topic import TopicAction + + +async def get_topics_func() -> List[Dict[str, Any]]: + topics = await TopicAction.get_all_topic() + return [topic.dict_topic() for topic in topics] + + +@app.get("/get_topics") +async def create_topic(): + data = await get_topics_func() + return {"code": 200, "msg": "获取成功", "data": data} diff --git a/apis/change_password.py b/apis/user/change_password.py similarity index 100% rename from apis/change_password.py rename to apis/user/change_password.py diff --git a/apis/check_login.py b/apis/user/check_login.py similarity index 100% rename from apis/check_login.py rename to apis/user/check_login.py diff --git a/apis/user/get_me.py b/apis/user/get_me.py new file mode 100644 index 0000000..95fb04c --- /dev/null +++ b/apis/user/get_me.py @@ -0,0 +1,19 @@ +from typing import Dict + +from defs import app, need_auth_routes +from fastapi import Request +from models.services.user import UserAction + + +async def get_me_func(uid: int) -> Dict: + user = await UserAction.get_user_by_id(uid) + return user.dict_user() + + +@app.get("/get_me") +async def get_me(request: Request): + uid = int(request.cookies.get("uid")) + user = await get_me_func(uid) + return {"code": 200, "msg": "获取成功", "data": user} + +need_auth_routes.append("/check_login") diff --git a/apis/login.py b/apis/user/login.py similarity index 100% rename from apis/login.py rename to apis/user/login.py diff --git a/apis/reg.py b/apis/user/reg.py similarity index 100% rename from apis/reg.py rename to apis/user/reg.py diff --git a/defs.py b/defs.py index e2220c7..035bacc 100644 --- a/defs.py +++ b/defs.py @@ -12,3 +12,5 @@ app.add_middleware( ) sqlite = Sqlite() need_auth_routes = [] +need_auth_uid_only_routes = [] +need_admin_routes = [] diff --git a/errors/post.py b/errors/post.py new file mode 100644 index 0000000..3329268 --- /dev/null +++ b/errors/post.py @@ -0,0 +1,28 @@ +class PostException(Exception): + def __init__(self, message: str = ""): + self.message = message + + +class PostTitleTooLongException(PostException): + def __init__(self, message: str = "标题过长"): + super().__init__(message) + + +class PostContentTooLongException(PostException): + def __init__(self, message: str = "内容过长"): + super().__init__(message) + + +class PostTopicNotValidException(PostException): + def __init__(self, message: str = "主题不存在"): + super().__init__(message) + + +class PostTopicNeedAdminException(PostException): + def __init__(self, message: str = "需要管理员权限"): + super().__init__(message) + + +class PostNotExistException(PostException): + def __init__(self, message: str = "文章不存在"): + super().__init__(message) diff --git a/main.py b/main.py index 165b342..9717048 100644 --- a/main.py +++ b/main.py @@ -7,9 +7,11 @@ from defs import app, sqlite, loop # 遍历 apis 文件夹下的所有文件,并且使用 importlib 导入 # 从而实现自动导入 -for filename in os.listdir("apis"): - if filename.endswith(".py"): - importlib.import_module(f"apis.{filename[:-3]}") +for root, dirs, files in os.walk("apis"): + start_index = root.replace(os.sep, ".") + for filename in files: + if filename.endswith(".py"): + importlib.import_module(f"{start_index}.{filename[:-3]}") async def main(): diff --git a/models/models/post.py b/models/models/post.py new file mode 100644 index 0000000..1b0c71c --- /dev/null +++ b/models/models/post.py @@ -0,0 +1,28 @@ +from typing import Dict + +from sqlmodel import SQLModel, Field + + +class Post(SQLModel, table=True): + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + + pid: int = Field(primary_key=True, default=None) + tid: int = Field(default=None) + uid: int = Field(default=None) + content: str = Field(default="") + title: str = Field(default="") + create_time: int = Field(default="") + update_time: int = Field(default="") + is_hidden: bool = Field(default=False) + is_delete: bool = Field(default=False) + + def dict_post(self) -> Dict: + return { + "pid": self.pid, + "tid": self.tid, + "uid": self.uid, + "title": self.title, + "content": self.content, + "create_time": self.create_time, + "update_time": self.update_time, + } diff --git a/models/models/topic.py b/models/models/topic.py new file mode 100644 index 0000000..c14eeea --- /dev/null +++ b/models/models/topic.py @@ -0,0 +1,20 @@ +from typing import Dict + +from sqlmodel import SQLModel, Field + + +class Topic(SQLModel, table=True): + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + + tid: int = Field(primary_key=True, default=None) + title: str = Field(default="") + create_time: int = Field(default="") + need_admin: bool = Field(default=False) + + def dict_topic(self) -> Dict: + return { + "tid": self.tid, + "title": self.title, + "create_time": self.create_time, + "need_admin": self.need_admin, + } diff --git a/models/models/user.py b/models/models/user.py index 3d86bb7..b063a0d 100644 --- a/models/models/user.py +++ b/models/models/user.py @@ -1,3 +1,5 @@ +from typing import Dict + from sqlmodel import SQLModel, Field @@ -11,3 +13,12 @@ class User(SQLModel, table=True): register_time: int = Field(default="") last_login_time: int = Field(default="") session: str = Field(default="") + + def dict_user(self) -> Dict: + return { + "uid": self.uid, + "username": self.username, + "is_admin": self.is_admin, + "register_time": self.register_time, + "last_login_time": self.last_login_time, + } diff --git a/models/services/post.py b/models/services/post.py new file mode 100644 index 0000000..ef3c6dc --- /dev/null +++ b/models/services/post.py @@ -0,0 +1,82 @@ +import time +from typing import cast, Optional, List + +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from defs import sqlite +from models.models.post import Post + + +class PostAction: + @staticmethod + async def add_post(post: Post): + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(post) + await session.commit() + + @staticmethod + async def update_post(old_post: Post, new_post: Post = None): + if new_post: + old_post.tid = new_post.tid + old_post.uid = new_post.uid + old_post.title = new_post.title + old_post.content = new_post.content + old_post.create_time = new_post.create_time + old_post.update_time = new_post.update_time + old_post.is_hidden = new_post.is_hidden + old_post.is_delete = new_post.is_delete + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(old_post) + await session.commit() + await session.refresh(old_post) + + @staticmethod + async def get_post_by_pid(pid: int) -> Post: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + query = select(Post).where(Post.pid == pid) + results = await session.execute(query) + return post[0] if (post := results.first()) else None + + @staticmethod + async def get_posts_by_tid(tid: int = None, admin: bool = False) -> List[Post]: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + query = select(Post) + if tid: + query = query.where(Post.tid == tid) + if not admin: + query = query.where( + Post.is_delete == False + ).where( + Post.is_hidden == False + ) + results = await session.execute(query) + return results.scalars().all() + + @staticmethod + def gen_new_post( + tid: int, + uid: int, + title: str, + content: str, + create_time: Optional[int] = None, + update_time: Optional[int] = None, + is_hidden: Optional[bool] = False, + is_delete: Optional[bool] = False, + ) -> Post: + if not create_time: + create_time = int(time.time()) + return Post( + tid=tid, + uid=uid, + title=title, + content=content, + create_time=create_time, + update_time=update_time, + is_hidden=is_hidden, + is_delete=is_delete, + ) diff --git a/models/services/topic.py b/models/services/topic.py new file mode 100644 index 0000000..36bfc76 --- /dev/null +++ b/models/services/topic.py @@ -0,0 +1,64 @@ +import time +from typing import cast, Optional, List + +from sqlalchemy import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from defs import sqlite +from models.models.topic import Topic + + +class TopicAction: + @staticmethod + async def add_topic(topic: Topic): + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(topic) + await session.commit() + + @staticmethod + async def update_topic(old_topic: Topic, new_topic: Topic = None): + if new_topic: + old_topic.title = new_topic.title + old_topic.create_time = new_topic.create_time + old_topic.need_admin = new_topic.need_admin + async with sqlite.session() as session: + session = cast(AsyncSession, session) + session.add(old_topic) + await session.commit() + await session.refresh(old_topic) + + @staticmethod + async def get_topic_by_tid(tid: int) -> Optional[Topic]: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + query = select(Topic).where(Topic.tid == tid) + results = await session.execute(query) + return topic[0] if (topic := results.first()) else None + + @staticmethod + async def check_need_admin(tid: int) -> bool: + topic = await TopicAction.get_topic_by_tid(tid) + return topic.need_admin if topic else False + + @staticmethod + async def get_all_topic() -> List[Topic]: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + query = select(Topic) + results = await session.execute(query) + return results.scalars().all() + + @staticmethod + def gen_new_topic( + title: str, + create_time: Optional[int] = None, + need_admin: Optional[bool] = None, + ) -> Topic: + if not create_time: + create_time = int(time.time()) + return Topic( + title=title, + create_time=create_time, + need_admin=need_admin, + ) diff --git a/models/services/user.py b/models/services/user.py index cbeb1aa..ba0a532 100644 --- a/models/services/user.py +++ b/models/services/user.py @@ -24,6 +24,19 @@ class UserAction: results = await session.exec(statement) return user[0] if (user := results.first()) else None + @staticmethod + async def get_user_by_id(uid: int) -> Optional[User]: + async with sqlite.session() as session: + session = cast(AsyncSession, session) + statement = select(User).where(User.uid == uid) + results = await session.exec(statement) + return user[0] if (user := results.first()) else None + + @staticmethod + async def check_admin(uid: int) -> bool: + user = await UserAction.get_user_by_id(uid) + return user.is_admin if user else False + @staticmethod async def update_user(old_user: User, new_user: User = None): if new_user: