feat: support handler
This commit is contained in:
parent
a5e58557e3
commit
dfd39412f6
9
main.py
9
main.py
@ -1,14 +1,7 @@
|
|||||||
from persica.context.application import ApplicationContext
|
from src.app import app
|
||||||
from persica.applicationbuilder import ApplicationBuilder
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
app = (
|
|
||||||
ApplicationBuilder()
|
|
||||||
.set_application_context_class(ApplicationContext)
|
|
||||||
.set_scanner_packages(["src.core", "src.route"])
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
app.run()
|
app.run()
|
||||||
|
|
||||||
|
|
||||||
|
9
src/app.py
Normal file
9
src/app.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from persica.applicationbuilder import ApplicationBuilder
|
||||||
|
from persica.context.application import ApplicationContext
|
||||||
|
|
||||||
|
app = (
|
||||||
|
ApplicationBuilder()
|
||||||
|
.set_application_context_class(ApplicationContext)
|
||||||
|
.set_scanner_packages(["src.core", "src.plugin", "src.route"])
|
||||||
|
.build()
|
||||||
|
)
|
@ -28,6 +28,7 @@ class WebApp(AsyncInitializingComponent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
print("开始启动 web 服务")
|
||||||
self.init_web()
|
self.init_web()
|
||||||
self.web_server = uvicorn.Server(
|
self.web_server = uvicorn.Server(
|
||||||
config=uvicorn.Config(
|
config=uvicorn.Config(
|
||||||
|
2
src/plugin/__init__.py
Normal file
2
src/plugin/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from ._handler import handler
|
||||||
|
from ._plugin import Plugin
|
137
src/plugin/_handler.py
Normal file
137
src/plugin/_handler.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Any, Callable, Dict, ParamSpec, TypeVar, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastapi_user_auth.auth import Auth
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class HandlerMethodEnum(str, Enum):
|
||||||
|
GET = "get"
|
||||||
|
POST = "post"
|
||||||
|
|
||||||
|
|
||||||
|
HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||||
|
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(init=True)
|
||||||
|
class HandlerData:
|
||||||
|
method: HandlerMethodEnum
|
||||||
|
path: str
|
||||||
|
admin: bool
|
||||||
|
student: bool
|
||||||
|
out: bool
|
||||||
|
kwargs: Dict[str, Any]
|
||||||
|
callback: Optional[MethodType]
|
||||||
|
|
||||||
|
def get_depends(self, auth: "Auth") -> Optional[Depends]:
|
||||||
|
if not (self.admin or self.student or self.out):
|
||||||
|
return None
|
||||||
|
roles = []
|
||||||
|
if self.admin:
|
||||||
|
roles.append("admin")
|
||||||
|
if self.student:
|
||||||
|
roles.append("student")
|
||||||
|
if self.out:
|
||||||
|
roles.append("out")
|
||||||
|
return Depends(auth.requires(roles)())
|
||||||
|
|
||||||
|
|
||||||
|
class _Handler:
|
||||||
|
_method: "HandlerMethodEnum"
|
||||||
|
|
||||||
|
kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
method: "HandlerMethodEnum",
|
||||||
|
admin: bool = True,
|
||||||
|
student: bool = False,
|
||||||
|
out: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
self.path = path
|
||||||
|
self._method = method
|
||||||
|
self.admin = admin
|
||||||
|
self.student = student
|
||||||
|
self.out = out
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""decorator实现,从 func 生成 Handler"""
|
||||||
|
|
||||||
|
handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, [])
|
||||||
|
handler_datas.append(
|
||||||
|
HandlerData(
|
||||||
|
method=self._method,
|
||||||
|
path=self.path,
|
||||||
|
admin=self.admin,
|
||||||
|
student=self.student,
|
||||||
|
out=self.out,
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
callback=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class _GetHandler(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
admin: bool = False,
|
||||||
|
student: bool = False,
|
||||||
|
out: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
method = HandlerMethodEnum.GET
|
||||||
|
super(_GetHandler, self).__init__(
|
||||||
|
path=path,
|
||||||
|
method=method,
|
||||||
|
admin=admin,
|
||||||
|
student=student,
|
||||||
|
out=out,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _PostHandler(_Handler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
admin: bool = False,
|
||||||
|
student: bool = False,
|
||||||
|
out: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
method = HandlerMethodEnum.POST
|
||||||
|
super(_PostHandler, self).__init__(
|
||||||
|
path=path,
|
||||||
|
method=method,
|
||||||
|
admin=admin,
|
||||||
|
student=student,
|
||||||
|
out=out,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyPep8Naming
|
||||||
|
class handler(_Handler):
|
||||||
|
post = _PostHandler
|
||||||
|
get = _GetHandler
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__("", HandlerMethodEnum.POST, **kwargs)
|
||||||
|
|
||||||
|
raise NotImplementedError("handler class is not instantiable")
|
80
src/plugin/_plugin.py
Normal file
80
src/plugin/_plugin.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import List, ClassVar, TYPE_CHECKING, Optional
|
||||||
|
from types import MethodType
|
||||||
|
from multiprocessing import RLock as Lock
|
||||||
|
|
||||||
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
|
|
||||||
|
from src.app import app
|
||||||
|
from src.core.web_app import WebApp
|
||||||
|
from src.plugin._handler import HandlerData
|
||||||
|
from src.services.users.services import UserServices
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from multiprocessing.synchronize import RLock as LockType
|
||||||
|
|
||||||
|
_HANDLER_DATA_ATTR_NAME = "_handler_datas"
|
||||||
|
"""用于储存生成 handler 时所需要的参数(例如 block)的属性名"""
|
||||||
|
|
||||||
|
_EXCLUDE_ATTRS = ["handlers", "jobs", "install", "initialize"]
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(AsyncInitializingComponent):
|
||||||
|
__order__ = 2
|
||||||
|
|
||||||
|
_lock: ClassVar["LockType"] = Lock()
|
||||||
|
_asyncio_lock: ClassVar["LockType"] = asyncio.Lock()
|
||||||
|
_installed: bool = False
|
||||||
|
|
||||||
|
_prefix: Optional[str] = None
|
||||||
|
_handlers: Optional[List[HandlerData]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def handlers(self) -> List[HandlerData]:
|
||||||
|
"""该插件的所有 handler"""
|
||||||
|
with self._lock:
|
||||||
|
if self._handlers is None:
|
||||||
|
self._handlers = []
|
||||||
|
|
||||||
|
for attr in dir(self):
|
||||||
|
if (
|
||||||
|
not (attr.startswith("_") or attr in _EXCLUDE_ATTRS)
|
||||||
|
and isinstance(func := getattr(self, attr), MethodType)
|
||||||
|
and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, []))
|
||||||
|
):
|
||||||
|
for data in datas:
|
||||||
|
data: "HandlerData"
|
||||||
|
data.callback = func
|
||||||
|
self._handlers.append(data)
|
||||||
|
return self._handlers
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
await self.install()
|
||||||
|
|
||||||
|
async def install(self) -> None:
|
||||||
|
"""安装"""
|
||||||
|
print(f"安装插件 {self.__class__.__name__}")
|
||||||
|
if not self._installed:
|
||||||
|
async with self._asyncio_lock:
|
||||||
|
web: WebApp = app.factory.get_object(WebApp)
|
||||||
|
auth: UserServices = app.factory.get_object(UserServices)
|
||||||
|
|
||||||
|
# self._install_jobs()
|
||||||
|
|
||||||
|
for h in self.handlers:
|
||||||
|
dep = h.kwargs.get("dependencies", [])
|
||||||
|
if _dep := h.get_depends(auth.repo.AUTH):
|
||||||
|
dep.append(_dep)
|
||||||
|
|
||||||
|
web.app.add_api_route(
|
||||||
|
path=(
|
||||||
|
h.path
|
||||||
|
if self._prefix is None
|
||||||
|
else f"{self._prefix}{h.path}"
|
||||||
|
),
|
||||||
|
endpoint=h.callback,
|
||||||
|
methods=[h.method.value],
|
||||||
|
dependencies=dep,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._installed = True
|
@ -1,59 +1,52 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi_amis_admin.crud import BaseApiOut
|
from fastapi_amis_admin.crud import BaseApiOut
|
||||||
from persica.factory.component import AsyncInitializingComponent
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Depends
|
from fastapi import HTTPException
|
||||||
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
from src.core.web_app import WebApp
|
from src.plugin import Plugin, handler
|
||||||
from src.services.users.schemas import UserRegIn, SystemUserEnum, UserLoginOut
|
from src.services.users.schemas import (
|
||||||
|
UserRegIn,
|
||||||
|
SystemUserEnum,
|
||||||
|
UserLoginOut,
|
||||||
|
UserRoleEnum,
|
||||||
|
)
|
||||||
from src.services.users.services import UserServices
|
from src.services.users.services import UserServices
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi_user_auth.auth import Auth
|
from fastapi_user_auth.auth import Auth
|
||||||
|
|
||||||
|
|
||||||
class UserRoutes(AsyncInitializingComponent):
|
class UserRoutes(Plugin, AsyncInitializingComponent):
|
||||||
__order__ = 2
|
_prefix = "/user"
|
||||||
|
|
||||||
def __init__(self, app: WebApp, user_services: UserServices):
|
def __init__(self, user_services: UserServices):
|
||||||
self.router = APIRouter(prefix="/user")
|
|
||||||
self.router.add_api_route("/register", self.register, methods=["POST"])
|
|
||||||
self.router.add_api_route("/login", self.login, methods=["POST"])
|
|
||||||
self.user_services = user_services
|
self.user_services = user_services
|
||||||
self.app = app.app
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
self.router.add_api_route(
|
|
||||||
"/need_login",
|
|
||||||
self.need_login,
|
|
||||||
methods=["GET"],
|
|
||||||
dependencies=[Depends(self.user_services.repo.AUTH.requires("admin")())],
|
|
||||||
)
|
|
||||||
self.app.include_router(self.router)
|
|
||||||
|
|
||||||
|
@handler.post("/register", admin=False)
|
||||||
async def register(self, data: UserRegIn):
|
async def register(self, data: UserRegIn):
|
||||||
if data.username.upper() in SystemUserEnum.__members__:
|
if data.username.upper() in SystemUserEnum.__members__:
|
||||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||||
user = await self.user_services.get_user(username=data.username)
|
user = await self.user_services.get_user(username=data.username)
|
||||||
if user:
|
if user:
|
||||||
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
return BaseApiOut(status=-1, msg="用户名已被注册", data=None)
|
||||||
role = "student"
|
role = UserRoleEnum.STUDENT.value
|
||||||
if not (data.student_id or data.phone):
|
if not (data.student_id or data.phone):
|
||||||
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
return BaseApiOut(status=-1, msg="学号或手机号至少填写一项", data=None)
|
||||||
if data.student_id:
|
if data.student_id:
|
||||||
user = await self.user_services.get_user(student_id=data.student_id)
|
user = await self.user_services.get_user(student_id=data.student_id)
|
||||||
if user:
|
if user:
|
||||||
return BaseApiOut(status=-1, msg="学号已被注册", data=None)
|
return BaseApiOut(status=-1, msg="学号已被注册", data=None)
|
||||||
role = "student"
|
role = UserRoleEnum.STUDENT.value
|
||||||
if data.phone:
|
if data.phone:
|
||||||
user = await self.user_services.get_user(phone=data.phone)
|
user = await self.user_services.get_user(phone=data.phone)
|
||||||
if user:
|
if user:
|
||||||
return BaseApiOut(status=-1, msg="手机号已被注册", data=None)
|
return BaseApiOut(status=-1, msg="手机号已被注册", data=None)
|
||||||
role = "out"
|
role = UserRoleEnum.OUT.value
|
||||||
# 检查通过,注册用户
|
# 检查通过,注册用户
|
||||||
try:
|
try:
|
||||||
user = await self.user_services.register_user(
|
user = await self.user_services.register_user(
|
||||||
@ -88,6 +81,7 @@ class UserRoutes(AsyncInitializingComponent):
|
|||||||
request.scope.get("user"), ip, ua, forwarded_for
|
request.scope.get("user"), ip, ua, forwarded_for
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@handler.post("/login", admin=False)
|
||||||
async def login(
|
async def login(
|
||||||
self, request: Request, response: Response, username: str, password: str
|
self, request: Request, response: Response, username: str, password: str
|
||||||
):
|
):
|
||||||
@ -118,6 +112,6 @@ class UserRoutes(AsyncInitializingComponent):
|
|||||||
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
|
response.set_cookie("Authorization", f"bearer {token_info.access_token}")
|
||||||
return BaseApiOut(code=0, data=token_info)
|
return BaseApiOut(code=0, data=token_info)
|
||||||
|
|
||||||
@staticmethod
|
@handler.get("/need_login", student=True, out=True)
|
||||||
async def need_login():
|
async def need_login(self):
|
||||||
return {}
|
return {}
|
||||||
|
@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
from fastapi_user_auth.auth import Auth
|
from fastapi_user_auth.auth import Auth
|
||||||
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
|
from fastapi_user_auth.auth.backends.redis import RedisTokenStore
|
||||||
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
||||||
|
from fastapi_user_auth.utils.casbin import update_subject_roles
|
||||||
from persica.factory.component import AsyncInitializingComponent
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
@ -35,6 +36,7 @@ class UserRepo(AsyncInitializingComponent):
|
|||||||
)
|
)
|
||||||
self.AUTH.backend.attach_middleware(self.app.app)
|
self.AUTH.backend.attach_middleware(self.app.app)
|
||||||
await self.AUTH.create_role_user("admin")
|
await self.AUTH.create_role_user("admin")
|
||||||
|
await self.AUTH.enforcer.load_policy()
|
||||||
|
|
||||||
async def register_user(
|
async def register_user(
|
||||||
self,
|
self,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional, List, Union, Sequence
|
||||||
|
|
||||||
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
from fastapi_user_auth.auth.models import CasbinRule, LoginHistory
|
||||||
|
from fastapi_user_auth.utils.casbin import update_subject_roles
|
||||||
from persica.factory.component import AsyncInitializingComponent
|
from persica.factory.component import AsyncInitializingComponent
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
@ -71,15 +72,20 @@ class UserServices(AsyncInitializingComponent):
|
|||||||
) -> Optional[RoleModel]:
|
) -> Optional[RoleModel]:
|
||||||
return await self.repo.get_role_rule(ptype, v0, v1)
|
return await self.repo.get_role_rule(ptype, v0, v1)
|
||||||
|
|
||||||
async def is_user_in_role_group(self, username: str, role_key: str) -> bool:
|
async def is_user_in_role_group(
|
||||||
return (
|
self, username: str, roles: Union[str, Sequence[str]], is_any: bool = True
|
||||||
await self.repo.get_role_rule("g", f"u:{username}", f"r:{role_key}")
|
) -> bool:
|
||||||
is not None
|
return await self.repo.AUTH.has_role_for_user(username, roles, is_any)
|
||||||
)
|
|
||||||
|
|
||||||
async def add_user_to_role_group(self, username: str, role_key: str) -> CasbinRule:
|
async def add_user_to_role_group(
|
||||||
rule = self.rule_model(ptype="g", v0=f"u:{username}", v1=f"r:{role_key}")
|
self, username: str, roles: Union[str, Sequence[str]]
|
||||||
return await self.repo.create_role_rule(rule)
|
):
|
||||||
|
if isinstance(roles, str):
|
||||||
|
roles = [roles]
|
||||||
|
new_roles = [f"r:{role}" for role in roles]
|
||||||
|
await update_subject_roles(
|
||||||
|
self.repo.AUTH.enforcer, subject=f"u:{username}", role_keys=new_roles
|
||||||
|
)
|
||||||
|
|
||||||
async def create_login_history(
|
async def create_login_history(
|
||||||
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
self, user: "UserModel", ip: str, ua: str, forwarded_for: str
|
||||||
|
Loading…
Reference in New Issue
Block a user