# sqlmodel/scripts/generate_select.py

import os
from itertools import product
from pathlib import Path
from typing import List, Tuple
import black
from jinja2 import Template
from pydantic import BaseModel
# Both paths are anchored at the directory two levels above this file
# (presumably the repository root, given this script lives in ``scripts/``),
# so they do not depend on the current working directory's layout.
_repo_root = Path(__file__).parent.parent
# Jinja2 template the generated module is rendered from.
template_path = _repo_root / "sqlmodel" / "sql" / "expression.py.jinja2"
# Generated module; overwritten (or checked, under CHECK_JINJA) by this script.
destiny_path = _repo_root / "sqlmodel" / "sql" / "expression.py"
# Inclusive upper bound on the number of entity arguments per generated
# signature; also passed through to the template when rendering.
number_of_types = 4
class Arg(BaseModel):
    """A single parameter of one generated signature.

    Both fields hold source text that is handed to the Jinja2 template
    verbatim, e.g. ``name="entity_0"`` with ``annotation="_TScalar_0"`` or
    ``annotation="Type[_TModel_0]"``.
    """

    # Parameter name as it should appear in the generated code.
    name: str
    # Type annotation for the parameter, as a source-code string.
    annotation: str
def _entity_signature(kinds: Tuple[str, ...]) -> Tuple[List[Arg], List[str]]:
    """Build the parameter list and return types for one kind combination.

    ``kinds`` holds one of ``"scalar"`` / ``"model"`` per position. A scalar
    entity at position ``i`` is annotated ``_TScalar_i`` and yields that same
    type variable. A model entity is annotated ``Type[_TModel_i]`` and yields
    ``_TModel_i``.
    """
    params: List[Arg] = []
    yielded: List[str] = []
    for position, kind in enumerate(kinds):
        if kind == "model":
            model_var = f"_TModel_{position}"
            params.append(
                Arg(name=f"entity_{position}", annotation=f"Type[{model_var}]")
            )
            yielded.append(model_var)
        else:
            scalar_var = f"_TScalar_{position}"
            params.append(Arg(name=f"entity_{position}", annotation=scalar_var))
            yielded.append(scalar_var)
    return params, yielded


# NOTE(review): arg_groups is never read in this script; kept in case
# something else imports it.
arg_groups: List[Arg] = []
# One (parameters, return types) pair per signature. Arities run from 2 up to
# number_of_types. Within an arity, combinations follow itertools.product order.
signatures: List[Tuple[List[Arg], List[str]]] = [
    _entity_signature(kinds)
    for total_args in range(2, number_of_types + 1)
    for kinds in product(["scalar", "model"], repeat=total_args)
]
# Render the template with the enumerated signatures, prepend a
# do-not-edit banner, and normalise the output with black so the comparison
# below is formatting-stable.
template: Template = Template(template_path.read_text(encoding="utf-8"))
result = template.render(number_of_types=number_of_types, signatures=signatures)
result = (
    "# WARNING: do not modify this code, it is generated by "
    "expression.py.jinja2\n\n" + result
)
result = black.format_str(result, mode=black.Mode())

# Read with an explicit encoding so the comparison does not depend on the
# platform's locale. A missing file just means it has never been generated.
try:
    current_content = destiny_path.read_text(encoding="utf-8")
except FileNotFoundError:
    current_content = None

if current_content != result:
    # Any non-empty CHECK_JINJA value turns this into a verification run:
    # fail instead of silently regenerating the stale file.
    if os.getenv("CHECK_JINJA"):
        raise RuntimeError(
            "sqlmodel/sql/expression.py content not updated with Jinja2 template"
        )
    # Only touch the file when its content actually changes.
    destiny_path.write_text(result, encoding="utf-8")