diff --git a/alembic/env.py b/alembic/env.py index 4ff4e20..dc1a58e 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -1,3 +1,4 @@ +from itertools import chain import os import asyncio from importlib import import_module @@ -13,7 +14,7 @@ from sqlmodel import SQLModel from alembic import context -from utils.const import PROJECT_ROOT +from utils.const import CORE_DIR, PLUGIN_DIR, PROJECT_ROOT from utils.log import logger # this is the Alembic Config object, which provides @@ -27,10 +28,11 @@ if config.config_file_name is not None: def scan_models() -> Iterator[str]: - """扫描所有 models.py 模块。 + """扫描 core 和 plugins 目录下所有 models.py 模块。 我们规定所有插件的 model 都需要放在名为 models.py 的文件里。""" + dirs = [CORE_DIR, PLUGIN_DIR] - for path in PROJECT_ROOT.glob("**/models.py"): + for path in chain(*[d.glob("**/models.py") for d in dirs]): yield str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".") diff --git a/utils/const.py b/utils/const.py index 6105d23..626a9f5 100644 --- a/utils/const.py +++ b/utils/const.py @@ -11,6 +11,8 @@ __all__ = [ # 项目根目录 PROJECT_ROOT = Path(__file__).joinpath('../..').resolve() +# Core 目录 +CORE_DIR = PROJECT_ROOT / 'core' # 插件目录 PLUGIN_DIR = PROJECT_ROOT / 'plugins' # 资源目录