create some funcs
This commit is contained in:
parent
dffb1e8d06
commit
eb2d7b7bdd
22
apis/middleware/check_admin.py
Normal file
22
apis/middleware/check_admin.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from defs import app, need_admin_routes
|
||||||
|
from models.services.user import UserAction
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def check_admin_middleware(request: Request, call_next):
|
||||||
|
if request.url.path not in need_admin_routes:
|
||||||
|
return await call_next(request)
|
||||||
|
uid = request.cookies.get("uid")
|
||||||
|
session = request.cookies.get("session")
|
||||||
|
try:
|
||||||
|
if not uid or not session:
|
||||||
|
raise ValueError
|
||||||
|
uid = int(uid)
|
||||||
|
if not await UserAction.check_admin(uid):
|
||||||
|
raise ValueError
|
||||||
|
except ValueError:
|
||||||
|
return JSONResponse(status_code=403, content={"code": 403, "msg": "此操作需要管理员权限"})
|
||||||
|
return await call_next(request)
|
@ -1,7 +1,7 @@
|
|||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from defs import app, need_auth_routes
|
from defs import app, need_auth_routes, need_auth_uid_only_routes
|
||||||
from models.services.session import SessionAction
|
from models.services.session import SessionAction
|
||||||
|
|
||||||
|
|
||||||
@ -20,5 +20,8 @@ async def check_session_middleware(request: Request, call_next):
|
|||||||
if not auth_success:
|
if not auth_success:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return JSONResponse(status_code=401, content={"code": 401, "msg": "Cookie 无效"})
|
if request.url.path in need_auth_uid_only_routes:
|
||||||
|
request.cookies["uid"] = ""
|
||||||
|
else:
|
||||||
|
return JSONResponse(status_code=401, content={"code": 401, "msg": "Cookie 无效"})
|
||||||
return await call_next(request)
|
return await call_next(request)
|
47
apis/post/create_post.py
Normal file
47
apis/post/create_post.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from defs import app, need_auth_routes
|
||||||
|
from errors.post import *
|
||||||
|
from fastapi import Request
|
||||||
|
from models.services.post import PostAction
|
||||||
|
from models.services.topic import TopicAction
|
||||||
|
from models.services.user import UserAction
|
||||||
|
|
||||||
|
|
||||||
|
class CreatePost(BaseModel):
|
||||||
|
tid: int
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
async def create_post_func(model: CreatePost, uid: int):
|
||||||
|
topic = await TopicAction.get_topic_by_tid(model.tid)
|
||||||
|
if topic is None:
|
||||||
|
raise PostTopicNotValidException()
|
||||||
|
if len(model.title) > 100:
|
||||||
|
raise PostTitleTooLongException()
|
||||||
|
if len(model.content) > 5000:
|
||||||
|
raise PostContentTooLongException()
|
||||||
|
if topic.need_admin:
|
||||||
|
if not await UserAction.check_admin(uid):
|
||||||
|
raise PostTopicNeedAdminException()
|
||||||
|
post = PostAction.gen_new_post(
|
||||||
|
model.tid,
|
||||||
|
uid,
|
||||||
|
model.title,
|
||||||
|
model.content,
|
||||||
|
)
|
||||||
|
await PostAction.add_post(post)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/create_post")
|
||||||
|
async def create_post(post: CreatePost, request: Request):
|
||||||
|
uid = int(request.cookies.get("uid"))
|
||||||
|
try:
|
||||||
|
await create_post_func(post, uid)
|
||||||
|
except PostException as e:
|
||||||
|
return {"code": 400, "msg": e.message}
|
||||||
|
return {"code": 200, "msg": "创建成功"}
|
||||||
|
|
||||||
|
|
||||||
|
need_auth_routes.append("/create_post")
|
46
apis/post/edit_post.py
Normal file
46
apis/post/edit_post.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from defs import app, need_auth_routes
|
||||||
|
from errors.post import *
|
||||||
|
from fastapi import Request
|
||||||
|
from models.services.post import PostAction
|
||||||
|
from models.services.user import UserAction
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class EditPost(BaseModel):
|
||||||
|
pid: int
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
async def edit_post_func(model: EditPost, uid: int):
|
||||||
|
post = await PostAction.get_post_by_pid(model.pid)
|
||||||
|
if post is None:
|
||||||
|
raise PostNotExistException()
|
||||||
|
admin = await UserAction.check_admin(uid)
|
||||||
|
if not admin:
|
||||||
|
if post.uid != uid:
|
||||||
|
raise PostTopicNeedAdminException()
|
||||||
|
if len(model.title) > 100:
|
||||||
|
raise PostTitleTooLongException()
|
||||||
|
if len(model.content) > 5000:
|
||||||
|
raise PostContentTooLongException()
|
||||||
|
post.title = model.title
|
||||||
|
post.content = model.content
|
||||||
|
post.update_time = int(time.time())
|
||||||
|
await PostAction.update_post(post)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/edit_post")
|
||||||
|
async def edit_post(post: EditPost, request: Request):
|
||||||
|
uid = int(request.cookies.get("uid"))
|
||||||
|
try:
|
||||||
|
await edit_post_func(post, uid)
|
||||||
|
except PostException as e:
|
||||||
|
return {"code": 400, "msg": e.message}
|
||||||
|
return {"code": 200, "msg": "修改成功"}
|
||||||
|
|
||||||
|
|
||||||
|
need_auth_routes.append("/edit_post")
|
49
apis/post/get_posts.py
Normal file
49
apis/post/get_posts.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from defs import app
|
||||||
|
from errors.post import *
|
||||||
|
from fastapi import Request
|
||||||
|
from models.services.post import PostAction
|
||||||
|
from models.services.topic import TopicAction
|
||||||
|
from models.services.user import UserAction
|
||||||
|
|
||||||
|
|
||||||
|
class GetPost(BaseModel):
|
||||||
|
tid: int
|
||||||
|
|
||||||
|
|
||||||
|
async def get_post_func(model: GetPost = None, uid: int = None) -> List[Dict]:
|
||||||
|
tid = model.tid if model else None
|
||||||
|
admin = await UserAction.check_admin(uid) if uid else False
|
||||||
|
if tid:
|
||||||
|
topic = await TopicAction.get_topic_by_tid(tid)
|
||||||
|
if not topic:
|
||||||
|
raise PostTopicNotValidException()
|
||||||
|
if topic.need_admin:
|
||||||
|
if not admin:
|
||||||
|
raise PostTopicNeedAdminException()
|
||||||
|
posts = await PostAction.get_posts_by_tid(tid, admin)
|
||||||
|
return [post.dict_post() for post in posts]
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_posts")
|
||||||
|
async def get_posts_get(request: Request):
|
||||||
|
uid = request.cookies.get("uid")
|
||||||
|
if uid is not None:
|
||||||
|
uid = int(uid)
|
||||||
|
data = await get_post_func(uid=uid)
|
||||||
|
return {"code": 200, "msg": "获取成功", "data": data}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/get_posts")
|
||||||
|
async def get_posts_post(model: GetPost, request: Request):
|
||||||
|
uid = request.cookies.get("uid")
|
||||||
|
if uid is not None:
|
||||||
|
uid = int(uid)
|
||||||
|
try:
|
||||||
|
data = await get_post_func(model, uid)
|
||||||
|
except PostException as e:
|
||||||
|
return {"code": 403, "msg": e.message}
|
||||||
|
return {"code": 200, "msg": "获取成功", "data": data}
|
24
apis/topic/create_topic.py
Normal file
24
apis/topic/create_topic.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from defs import app, need_auth_routes, need_admin_routes
|
||||||
|
from models.services.topic import TopicAction
|
||||||
|
|
||||||
|
|
||||||
|
class CreateTopic(BaseModel):
|
||||||
|
title: str
|
||||||
|
need_admin: bool
|
||||||
|
|
||||||
|
|
||||||
|
async def create_topic_func(title: str, need_admin: bool):
|
||||||
|
topic = TopicAction.gen_new_topic(title, need_admin=need_admin)
|
||||||
|
await TopicAction.add_topic(topic)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/create_topic")
|
||||||
|
async def create_topic(topic: CreateTopic):
|
||||||
|
await create_topic_func(topic.title, topic.need_admin)
|
||||||
|
return {"code": 200, "msg": "创建成功"}
|
||||||
|
|
||||||
|
|
||||||
|
need_auth_routes.append("/create_topic")
|
||||||
|
need_admin_routes.append("/create_topic")
|
15
apis/topic/get_topics.py
Normal file
15
apis/topic/get_topics.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from defs import app
|
||||||
|
from models.services.topic import TopicAction
|
||||||
|
|
||||||
|
|
||||||
|
async def get_topics_func() -> List[Dict[str, Any]]:
|
||||||
|
topics = await TopicAction.get_all_topic()
|
||||||
|
return [topic.dict_topic() for topic in topics]
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_topics")
|
||||||
|
async def create_topic():
|
||||||
|
data = await get_topics_func()
|
||||||
|
return {"code": 200, "msg": "获取成功", "data": data}
|
19
apis/user/get_me.py
Normal file
19
apis/user/get_me.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from defs import app, need_auth_routes
|
||||||
|
from fastapi import Request
|
||||||
|
from models.services.user import UserAction
|
||||||
|
|
||||||
|
|
||||||
|
async def get_me_func(uid: int) -> Dict:
|
||||||
|
user = await UserAction.get_user_by_id(uid)
|
||||||
|
return user.dict_user()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_me")
|
||||||
|
async def get_me(request: Request):
|
||||||
|
uid = int(request.cookies.get("uid"))
|
||||||
|
user = await get_me_func(uid)
|
||||||
|
return {"code": 200, "msg": "获取成功", "data": user}
|
||||||
|
|
||||||
|
need_auth_routes.append("/check_login")
|
2
defs.py
2
defs.py
@ -12,3 +12,5 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
sqlite = Sqlite()
|
sqlite = Sqlite()
|
||||||
need_auth_routes = []
|
need_auth_routes = []
|
||||||
|
need_auth_uid_only_routes = []
|
||||||
|
need_admin_routes = []
|
||||||
|
28
errors/post.py
Normal file
28
errors/post.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
class PostException(Exception):
|
||||||
|
def __init__(self, message: str = ""):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class PostTitleTooLongException(PostException):
|
||||||
|
def __init__(self, message: str = "标题过长"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class PostContentTooLongException(PostException):
|
||||||
|
def __init__(self, message: str = "内容过长"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class PostTopicNotValidException(PostException):
|
||||||
|
def __init__(self, message: str = "主题不存在"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class PostTopicNeedAdminException(PostException):
|
||||||
|
def __init__(self, message: str = "需要管理员权限"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class PostNotExistException(PostException):
|
||||||
|
def __init__(self, message: str = "文章不存在"):
|
||||||
|
super().__init__(message)
|
8
main.py
8
main.py
@ -7,9 +7,11 @@ from defs import app, sqlite, loop
|
|||||||
|
|
||||||
# 遍历 apis 文件夹下的所有文件,并且使用 importlib 导入
|
# 遍历 apis 文件夹下的所有文件,并且使用 importlib 导入
|
||||||
# 从而实现自动导入
|
# 从而实现自动导入
|
||||||
for filename in os.listdir("apis"):
|
for root, dirs, files in os.walk("apis"):
|
||||||
if filename.endswith(".py"):
|
start_index = root.replace(os.sep, ".")
|
||||||
importlib.import_module(f"apis.{filename[:-3]}")
|
for filename in files:
|
||||||
|
if filename.endswith(".py"):
|
||||||
|
importlib.import_module(f"{start_index}.{filename[:-3]}")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
28
models/models/post.py
Normal file
28
models/models/post.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Post(SQLModel, table=True):
|
||||||
|
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||||
|
|
||||||
|
pid: int = Field(primary_key=True, default=None)
|
||||||
|
tid: int = Field(default=None)
|
||||||
|
uid: int = Field(default=None)
|
||||||
|
content: str = Field(default="")
|
||||||
|
title: str = Field(default="")
|
||||||
|
create_time: int = Field(default="")
|
||||||
|
update_time: int = Field(default="")
|
||||||
|
is_hidden: bool = Field(default=False)
|
||||||
|
is_delete: bool = Field(default=False)
|
||||||
|
|
||||||
|
def dict_post(self) -> Dict:
|
||||||
|
return {
|
||||||
|
"pid": self.pid,
|
||||||
|
"tid": self.tid,
|
||||||
|
"uid": self.uid,
|
||||||
|
"title": self.title,
|
||||||
|
"content": self.content,
|
||||||
|
"create_time": self.create_time,
|
||||||
|
"update_time": self.update_time,
|
||||||
|
}
|
20
models/models/topic.py
Normal file
20
models/models/topic.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Topic(SQLModel, table=True):
|
||||||
|
__table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci")
|
||||||
|
|
||||||
|
tid: int = Field(primary_key=True, default=None)
|
||||||
|
title: str = Field(default="")
|
||||||
|
create_time: int = Field(default="")
|
||||||
|
need_admin: bool = Field(default=False)
|
||||||
|
|
||||||
|
def dict_topic(self) -> Dict:
|
||||||
|
return {
|
||||||
|
"tid": self.tid,
|
||||||
|
"title": self.title,
|
||||||
|
"create_time": self.create_time,
|
||||||
|
"need_admin": self.need_admin,
|
||||||
|
}
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
from sqlmodel import SQLModel, Field
|
from sqlmodel import SQLModel, Field
|
||||||
|
|
||||||
|
|
||||||
@ -11,3 +13,12 @@ class User(SQLModel, table=True):
|
|||||||
register_time: int = Field(default="")
|
register_time: int = Field(default="")
|
||||||
last_login_time: int = Field(default="")
|
last_login_time: int = Field(default="")
|
||||||
session: str = Field(default="")
|
session: str = Field(default="")
|
||||||
|
|
||||||
|
def dict_user(self) -> Dict:
|
||||||
|
return {
|
||||||
|
"uid": self.uid,
|
||||||
|
"username": self.username,
|
||||||
|
"is_admin": self.is_admin,
|
||||||
|
"register_time": self.register_time,
|
||||||
|
"last_login_time": self.last_login_time,
|
||||||
|
}
|
||||||
|
82
models/services/post.py
Normal file
82
models/services/post.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import time
|
||||||
|
from typing import cast, Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from defs import sqlite
|
||||||
|
from models.models.post import Post
|
||||||
|
|
||||||
|
|
||||||
|
class PostAction:
|
||||||
|
@staticmethod
|
||||||
|
async def add_post(post: Post):
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
session.add(post)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_post(old_post: Post, new_post: Post = None):
|
||||||
|
if new_post:
|
||||||
|
old_post.tid = new_post.tid
|
||||||
|
old_post.uid = new_post.uid
|
||||||
|
old_post.title = new_post.title
|
||||||
|
old_post.content = new_post.content
|
||||||
|
old_post.create_time = new_post.create_time
|
||||||
|
old_post.update_time = new_post.update_time
|
||||||
|
old_post.is_hidden = new_post.is_hidden
|
||||||
|
old_post.is_delete = new_post.is_delete
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
session.add(old_post)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(old_post)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_post_by_pid(pid: int) -> Post:
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
query = select(Post).where(Post.pid == pid)
|
||||||
|
results = await session.execute(query)
|
||||||
|
return post[0] if (post := results.first()) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_posts_by_tid(tid: int = None, admin: bool = False) -> List[Post]:
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
query = select(Post)
|
||||||
|
if tid:
|
||||||
|
query = query.where(Post.tid == tid)
|
||||||
|
if not admin:
|
||||||
|
query = query.where(
|
||||||
|
Post.is_delete == False
|
||||||
|
).where(
|
||||||
|
Post.is_hidden == False
|
||||||
|
)
|
||||||
|
results = await session.execute(query)
|
||||||
|
return results.scalars().all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gen_new_post(
|
||||||
|
tid: int,
|
||||||
|
uid: int,
|
||||||
|
title: str,
|
||||||
|
content: str,
|
||||||
|
create_time: Optional[int] = None,
|
||||||
|
update_time: Optional[int] = None,
|
||||||
|
is_hidden: Optional[bool] = False,
|
||||||
|
is_delete: Optional[bool] = False,
|
||||||
|
) -> Post:
|
||||||
|
if not create_time:
|
||||||
|
create_time = int(time.time())
|
||||||
|
return Post(
|
||||||
|
tid=tid,
|
||||||
|
uid=uid,
|
||||||
|
title=title,
|
||||||
|
content=content,
|
||||||
|
create_time=create_time,
|
||||||
|
update_time=update_time,
|
||||||
|
is_hidden=is_hidden,
|
||||||
|
is_delete=is_delete,
|
||||||
|
)
|
64
models/services/topic.py
Normal file
64
models/services/topic.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import time
|
||||||
|
from typing import cast, Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from defs import sqlite
|
||||||
|
from models.models.topic import Topic
|
||||||
|
|
||||||
|
|
||||||
|
class TopicAction:
|
||||||
|
@staticmethod
|
||||||
|
async def add_topic(topic: Topic):
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
session.add(topic)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_topic(old_topic: Topic, new_topic: Topic = None):
|
||||||
|
if new_topic:
|
||||||
|
old_topic.title = new_topic.title
|
||||||
|
old_topic.create_time = new_topic.create_time
|
||||||
|
old_topic.need_admin = new_topic.need_admin
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
session.add(old_topic)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(old_topic)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_topic_by_tid(tid: int) -> Optional[Topic]:
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
query = select(Topic).where(Topic.tid == tid)
|
||||||
|
results = await session.execute(query)
|
||||||
|
return topic[0] if (topic := results.first()) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_need_admin(tid: int) -> bool:
|
||||||
|
topic = await TopicAction.get_topic_by_tid(tid)
|
||||||
|
return topic.need_admin if topic else False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_all_topic() -> List[Topic]:
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
query = select(Topic)
|
||||||
|
results = await session.execute(query)
|
||||||
|
return results.scalars().all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gen_new_topic(
|
||||||
|
title: str,
|
||||||
|
create_time: Optional[int] = None,
|
||||||
|
need_admin: Optional[bool] = None,
|
||||||
|
) -> Topic:
|
||||||
|
if not create_time:
|
||||||
|
create_time = int(time.time())
|
||||||
|
return Topic(
|
||||||
|
title=title,
|
||||||
|
create_time=create_time,
|
||||||
|
need_admin=need_admin,
|
||||||
|
)
|
@ -24,6 +24,19 @@ class UserAction:
|
|||||||
results = await session.exec(statement)
|
results = await session.exec(statement)
|
||||||
return user[0] if (user := results.first()) else None
|
return user[0] if (user := results.first()) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_by_id(uid: int) -> Optional[User]:
|
||||||
|
async with sqlite.session() as session:
|
||||||
|
session = cast(AsyncSession, session)
|
||||||
|
statement = select(User).where(User.uid == uid)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
return user[0] if (user := results.first()) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_admin(uid: int) -> bool:
|
||||||
|
user = await UserAction.get_user_by_id(uid)
|
||||||
|
return user.is_admin if user else False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_user(old_user: User, new_user: User = None):
|
async def update_user(old_user: User, new_user: User = None):
|
||||||
if new_user:
|
if new_user:
|
||||||
|
Loading…
Reference in New Issue
Block a user