diff --git a/main.py b/main.py index 12a5376..7ab74de 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,7 @@ -from persica.context.application import ApplicationContext -from persica.applicationbuilder import ApplicationBuilder +from src.app import app def main(): - app = ( - ApplicationBuilder() - .set_application_context_class(ApplicationContext) - .set_scanner_packages(["src.core", "src.route"]) - .build() - ) app.run() diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000..57494fa --- /dev/null +++ b/src/app.py @@ -0,0 +1,9 @@ +from persica.applicationbuilder import ApplicationBuilder +from persica.context.application import ApplicationContext + +app = ( + ApplicationBuilder() + .set_application_context_class(ApplicationContext) + .set_scanner_packages(["src.core", "src.plugin", "src.route"]) + .build() +) diff --git a/src/core/web_app.py b/src/core/web_app.py index 28543f7..ba72322 100644 --- a/src/core/web_app.py +++ b/src/core/web_app.py @@ -28,6 +28,7 @@ class WebApp(AsyncInitializingComponent): ) async def start(self): + print("开始启动 web 服务") self.init_web() self.web_server = uvicorn.Server( config=uvicorn.Config( diff --git a/src/plugin/__init__.py b/src/plugin/__init__.py new file mode 100644 index 0000000..8f0d565 --- /dev/null +++ b/src/plugin/__init__.py @@ -0,0 +1,2 @@ +from ._handler import handler +from ._plugin import Plugin diff --git a/src/plugin/_handler.py b/src/plugin/_handler.py new file mode 100644 index 0000000..2b62dfb --- /dev/null +++ b/src/plugin/_handler.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from enum import Enum +from types import MethodType +from typing import Any, Callable, Dict, ParamSpec, TypeVar, Optional, TYPE_CHECKING + +from fastapi import Depends + +if TYPE_CHECKING: + from fastapi_user_auth.auth import Auth + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R") + + +class HandlerMethodEnum(str, Enum): + GET = "get" + POST = "post" + + +HANDLER_DATA_ATTR_NAME = "_handler_datas" +"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" + + +@dataclass(init=True) +class HandlerData: + method: HandlerMethodEnum + path: str + admin: bool + student: bool + out: bool + kwargs: Dict[str, Any] + callback: Optional[MethodType] + + def get_depends(self, auth: "Auth") -> Optional[Depends]: + if not (self.admin or self.student or self.out): + return None + roles = [] + if self.admin: + roles.append("admin") + if self.student: + roles.append("student") + if self.out: + roles.append("out") + return Depends(auth.requires(roles)()) + + +class _Handler: + _method: "HandlerMethodEnum" + + kwargs: Dict[str, Any] = {} + + def __init__( + self, + path: str, + method: "HandlerMethodEnum", + admin: bool = True, + student: bool = False, + out: bool = False, + **kwargs, + ) -> None: + self.path = path + self._method = method + self.admin = admin + self.student = student + self.out = out + self.kwargs = kwargs + + def __call__(self, func: Callable[P, R]) -> Callable[P, R]: + """decorator实现,从 func 生成 Handler""" + + handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, []) + handler_datas.append( + HandlerData( + method=self._method, + path=self.path, + admin=self.admin, + student=self.student, + out=self.out, + kwargs=self.kwargs, + callback=None, + ) + ) + setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas) + + return func + + +class _GetHandler(_Handler): + def __init__( + self, + path: str, + admin: bool = False, + student: bool = False, + out: bool = False, + **kwargs, + ): + method = HandlerMethodEnum.GET + super(_GetHandler, self).__init__( + path=path, + method=method, + admin=admin, + student=student, + out=out, + **kwargs, + ) + + +class _PostHandler(_Handler): + def __init__( + self, + path: str, + admin: bool = False, + student: bool = False, + out: bool = False, + **kwargs, + ): + method = HandlerMethodEnum.POST + super(_PostHandler, self).__init__( + path=path, + method=method, + admin=admin, + student=student, + out=out, + **kwargs, + ) + + +# noinspection PyPep8Naming +class handler(_Handler): + post = _PostHandler + get = _GetHandler + + def __init__(self, **kwargs) -> None: + super().__init__("", HandlerMethodEnum.POST, **kwargs) + + raise NotImplementedError("handler class is not instantiable") diff --git a/src/plugin/_plugin.py b/src/plugin/_plugin.py new file mode 100644 index 0000000..6babf49 --- /dev/null +++ b/src/plugin/_plugin.py @@ -0,0 +1,80 @@ +import asyncio +from typing import List, ClassVar, TYPE_CHECKING, Optional +from types import MethodType +from multiprocessing import RLock as Lock + +from persica.factory.component import AsyncInitializingComponent + +from src.app import app +from src.core.web_app import WebApp +from src.plugin._handler import HandlerData +from src.services.users.services import UserServices + +if TYPE_CHECKING: + from multiprocessing.synchronize import RLock as LockType + +_HANDLER_DATA_ATTR_NAME = "_handler_datas" +"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" + +_EXCLUDE_ATTRS = ["handlers", "jobs", "install", "initialize"] + + +class Plugin(AsyncInitializingComponent): + __order__ = 2 + + _lock: ClassVar["LockType"] = Lock() + _asyncio_lock: ClassVar["LockType"] = asyncio.Lock() + _installed: bool = False + + _prefix: Optional[str] = None + _handlers: Optional[List[HandlerData]] = None + + @property + def handlers(self) -> List[HandlerData]: + """该插件的所有 handler""" + with self._lock: + if self._handlers is None: + self._handlers = [] + + for attr in dir(self): + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and isinstance(func := getattr(self, attr), MethodType) + and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, [])) + ): + for data in datas: + data: "HandlerData" + data.callback = func + self._handlers.append(data) + return self._handlers + + async def initialize(self): + await self.install() + + async def install(self) -> None: + """安装""" + print(f"安装插件 {self.__class__.__name__}") + if not self._installed: + async with self._asyncio_lock: + web: WebApp = app.factory.get_object(WebApp) + auth: UserServices = app.factory.get_object(UserServices) + + # self._install_jobs() + + for h in self.handlers: + dep = h.kwargs.get("dependencies", []) + if _dep := h.get_depends(auth.repo.AUTH): + dep.append(_dep) + + web.app.add_api_route( + path=( + h.path + if self._prefix is None + else f"{self._prefix}{h.path}" + ), + endpoint=h.callback, + methods=[h.method.value], + dependencies=dep, + ) + + self._installed = True diff --git a/src/route/users.py b/src/route/users.py index 343ba57..8d9a1c7 100644 --- a/src/route/users.py +++ b/src/route/users.py @@ -1,59 +1,52 @@ from typing import TYPE_CHECKING from fastapi_amis_admin.crud import BaseApiOut -from persica.factory.component import AsyncInitializingComponent -from fastapi import APIRouter, HTTPException, Depends +from fastapi import HTTPException +from persica.factory.component import AsyncInitializingComponent from starlette import status from starlette.requests import Request from starlette.responses import Response -from src.core.web_app import WebApp -from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut +from src.plugin import Plugin, handler +from src.services.users.schemas import ( + UserRegIn, + SystemUserEnum, + UserLoginOut, + UserRoleEnum, +) from src.services.users.services import UserServices if TYPE_CHECKING: from fastapi_user_auth.auth import Auth -class UserRoutes(AsyncInitializingComponent): - __order__ = 2 +class UserRoutes(Plugin, AsyncInitializingComponent): + _prefix = "/user" - def __init__(self, app: WebApp, user_services: UserServices): - self.router = APIRouter(prefix="/user") - self.router.add_api_route("/register", self.register, methods=["POST"]) - self.router.add_api_route("/login", self.login, methods=["POST"]) + def __init__(self, user_services: UserServices): self.user_services = user_services - self.app = app.app - - async def initialize(self): - self.router.add_api_route( - "/need_login", - self.need_login, - methods=["GET"], - dependencies=[Depends(self.user_services.repo.AUTH.requires("admin")())], - ) - self.app.include_router(self.router) + @handler.post("/register", admin=False) async def register(self, data: UserRegIn): if data.username.upper() in SystemUserEnum.__members__: return BaseApiOut(status=-1, msg="用户名已被注册", data=None) user = await self.user_services.get_user(username=data.username) if user: return BaseApiOut(status=-1, msg="用户名已被注册", data=None) - role = "student" + role = UserRoleEnum.STUDENT.value if not (data.student_id or data.phone): return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None) if data.student_id: user = await self.user_services.get_user(student_id=data.student_id) if user: return BaseApiOut(status=-1, msg="学号已被注册", data=None) - role = "student" + role = UserRoleEnum.STUDENT.value if data.phone: user = await self.user_services.get_user(phone=data.phone) if user: return BaseApiOut(status=-1, msg="手机号已被注册", data=None) - role = "out" + role = UserRoleEnum.OUT.value # 检查通过,注册用户 try: user = await self.user_services.register_user( @@ -88,6 +81,7 @@ class UserRoutes(AsyncInitializingComponent): request.scope.get("user"), ip, ua, forwarded_for ) + @handler.post("/login", admin=False) async def login( self, request: Request, response: Response, username: str, password: str ): @@ -118,6 +112,6 @@ class UserRoutes(AsyncInitializingComponent): response.set_cookie("Authorization", f"bearer {token_info.access_token}") return BaseApiOut(code=0, data=token_info) - @staticmethod - async def need_login(): + @handler.get("/need_login", student=True, out=True) + async def need_login(self): return {} diff --git a/src/services/users/repositories.py b/src/services/users/repositories.py index 7b656a5..9bd5e1d 100644 --- a/src/services/users/repositories.py +++ b/src/services/users/repositories.py @@ -3,6 +3,7 @@ from typing import Optional from fastapi_user_auth.auth import Auth from fastapi_user_auth.auth.backends.redis import RedisTokenStore from fastapi_user_auth.auth.models import CasbinRule, LoginHistory +from fastapi_user_auth.utils.casbin import update_subject_roles from persica.factory.component import AsyncInitializingComponent from pydantic import SecretStr from sqlmodel import select @@ -35,6 +36,7 @@ class UserRepo(AsyncInitializingComponent): ) self.AUTH.backend.attach_middleware(self.app.app) await self.AUTH.create_role_user("admin") + await self.AUTH.enforcer.load_policy() async def register_user( self, diff --git a/src/services/users/services.py b/src/services/users/services.py index dc9fa48..7c32a91 100644 --- a/src/services/users/services.py +++ b/src/services/users/services.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Optional, List, Union, Sequence from fastapi_user_auth.auth.models import CasbinRule, LoginHistory +from fastapi_user_auth.utils.casbin import update_subject_roles from persica.factory.component import AsyncInitializingComponent from pydantic import SecretStr @@ -71,15 +72,20 @@ class UserServices(AsyncInitializingComponent): ) -> Optional[RoleModel]: return await self.repo.get_role_rule(ptype, v0, v1) - async def is_user_in_role_group(self, username: str, role_key: str) -> bool: - return ( - await self.repo.get_role_rule("g", f"u:{username}", f"r:{role_key}") - is not None - ) + async def is_user_in_role_group( + self, username: str, roles: Union[str, Sequence[str]], is_any: bool = True + ) -> bool: + return await self.repo.AUTH.has_role_for_user(username, roles, is_any) - async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule: - rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}") - return await self.repo.create_role_rule(rule) + async def add_user_to_role_group( + self, username: str, roles: Union[str, Sequence[str]] + ): + if isinstance(roles, str): + roles = [roles] + new_roles = [f"r:{role}" for role in roles] + await update_subject_roles( + self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles + ) async def create_login_history( self, user: "UserModel", ip: str, ua: str, forwarded_for: str