@@ -19,19 +19,17 @@ from typing import (
     Set,
     Tuple,
     Type,
-    TypeVar,
+    TypeVar,ForwardRef,
     Union,
-    cast,
+    cast,get_origin,get_args
 )

-from pydantic import BaseConfig, BaseModel
-from pydantic.errors import ConfigError, DictError
-from pydantic.fields import SHAPE_SINGLETON
+from pydantic import BaseModel, ValidationError
+from pydantic_core import PydanticUndefined, PydanticUndefinedType
+from pydantic.fields import FieldInfo as PydanticFieldInfo
-from pydantic.fields import ModelField, Undefined, UndefinedType
-from pydantic.main import ModelMetaclass, validate_model
-from pydantic.typing import NoArgAnyCallable, resolve_annotations
-from pydantic.utils import ROOT_KEY, Representation
+from pydantic._internal._model_construction import ModelMetaclass
+from pydantic._internal._repr import Representation
+from pydantic._internal._fields import PydanticGeneralMetadata
 from sqlalchemy import Boolean, Column, Date, DateTime
 from sqlalchemy import Enum as sa_Enum
 from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
@@ -43,8 +41,10 @@ from sqlalchemy.sql.schema import MetaData
 from sqlalchemy.sql.sqltypes import LargeBinary, Time

 from .sql.sqltypes import GUID, AutoString
+from .typing import SQLModelConfig

 _T = TypeVar("_T")
+NoArgAnyCallable = Callable[[], Any]


 def __dataclass_transform__(
@@ -58,22 +58,22 @@ def __dataclass_transform__(


 class FieldInfo(PydanticFieldInfo):
-    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
+    def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
         primary_key = kwargs.pop("primary_key", False)
-        nullable = kwargs.pop("nullable", Undefined)
-        foreign_key = kwargs.pop("foreign_key", Undefined)
+        nullable = kwargs.pop("nullable", PydanticUndefined)
+        foreign_key = kwargs.pop("foreign_key", PydanticUndefined)
         unique = kwargs.pop("unique", False)
-        index = kwargs.pop("index", Undefined)
-        sa_column = kwargs.pop("sa_column", Undefined)
-        sa_column_args = kwargs.pop("sa_column_args", Undefined)
-        sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
-        if sa_column is not Undefined:
-            if sa_column_args is not Undefined:
+        index = kwargs.pop("index", PydanticUndefined)
+        sa_column = kwargs.pop("sa_column", PydanticUndefined)
+        sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined)
+        sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined)
+        if sa_column is not PydanticUndefined:
+            if sa_column_args is not PydanticUndefined:
                 raise RuntimeError(
                     "Passing sa_column_args is not supported when "
                     "also passing a sa_column"
                 )
-            if sa_column_kwargs is not Undefined:
+            if sa_column_kwargs is not PydanticUndefined:
                 raise RuntimeError(
                     "Passing sa_column_kwargs is not supported when "
                     "also passing a sa_column"
@@ -118,7 +118,7 @@ class RelationshipInfo(Representation):


 def Field(
-    default: Any = Undefined,
+    default: Any = PydanticUndefined,
     *,
     default_factory: Optional[NoArgAnyCallable] = None,
     alias: Optional[str] = None,
@@ -145,11 +145,11 @@ def Field(
     primary_key: bool = False,
     foreign_key: Optional[Any] = None,
     unique: bool = False,
-    nullable: Union[bool, UndefinedType] = Undefined,
-    index: Union[bool, UndefinedType] = Undefined,
-    sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
-    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
-    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
+    nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+    index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
+    sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
+    sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
+    sa_column_kwargs: Union[Mapping[str, Any], PydanticUndefinedType] = PydanticUndefined,
     schema_extra: Optional[Dict[str, Any]] = None,
 ) -> Any:
     current_schema_extra = schema_extra or {}
@@ -183,7 +183,6 @@ def Field(
         sa_column_kwargs=sa_column_kwargs,
         **current_schema_extra,
     )
-    field_info._validate()
     return field_info


@@ -191,7 +190,7 @@ def Relationship(
     *,
     back_populates: Optional[str] = None,
     link_model: Optional[Any] = None,
-    sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
+    sa_relationship: Optional[RelationshipProperty] = None,
     sa_relationship_args: Optional[Sequence[Any]] = None,
     sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
 ) -> Any:
@@ -208,18 +207,18 @@ def Relationship(
 @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
 class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
     __sqlmodel_relationships__: Dict[str, RelationshipInfo]
-    __config__: Type[BaseConfig]
-    __fields__: Dict[str, ModelField]
+    model_config: Type[SQLModelConfig]
+    model_fields: Dict[str, FieldInfo]

     # Replicate SQLAlchemy
     def __setattr__(cls, name: str, value: Any) -> None:
-        if getattr(cls.__config__, "table", False):
+        if cls.model_config.get("table", False):
             DeclarativeMeta.__setattr__(cls, name, value)
         else:
             super().__setattr__(name, value)

     def __delattr__(cls, name: str) -> None:
-        if getattr(cls.__config__, "table", False):
+        if cls.model_config.get("table", False):
             DeclarativeMeta.__delattr__(cls, name)
         else:
             super().__delattr__(name)
@@ -232,11 +231,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
         class_dict: Dict[str, Any],
         **kwargs: Any,
     ) -> Any:
         relationships: Dict[str, RelationshipInfo] = {}
         dict_for_pydantic = {}
-        original_annotations = resolve_annotations(
-            class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
-        )
+        original_annotations = class_dict.get("__annotations__", {})
         pydantic_annotations = {}
         relationship_annotations = {}
         for k, v in class_dict.items():
@@ -260,7 +258,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
         # superclass causing an error
         allowed_config_kwargs: Set[str] = {
             key
-            for key in dir(BaseConfig)
+            for key in dir(SQLModelConfig)
             if not (
                 key.startswith("__") and key.endswith("__")
             ) # skip dunder methods and attributes
@@ -270,41 +268,49 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
             key: pydantic_kwargs.pop(key)
             for key in pydantic_kwargs.keys() & allowed_config_kwargs
         }
+        config_table = getattr(class_dict.get('Config', object()), 'table', False)
+        # If we have a table, we need to have defaults for all fields
+        # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
+        if config_table is True:
+            for key in original_annotations.keys():
+                if dict_used.get(key, PydanticUndefined) is PydanticUndefined:
+                    dict_used[key] = None
+
         new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
         new_cls.__annotations__ = {
             **relationship_annotations,
             **pydantic_annotations,
             **new_cls.__annotations__,
         }

         def get_config(name: str) -> Any:
-            config_class_value = getattr(new_cls.__config__, name, Undefined)
-            if config_class_value is not Undefined:
+            config_class_value = new_cls.model_config.get(name, PydanticUndefined)
+            if config_class_value is not PydanticUndefined:
                 return config_class_value
-            kwarg_value = kwargs.get(name, Undefined)
-            if kwarg_value is not Undefined:
+            kwarg_value = kwargs.get(name, PydanticUndefined)
+            if kwarg_value is not PydanticUndefined:
                 return kwarg_value
-            return Undefined
+            return PydanticUndefined

         config_table = get_config("table")
         if config_table is True:
             # If it was passed by kwargs, ensure it's also set in config
-            new_cls.__config__.table = config_table
-            for k, v in new_cls.__fields__.items():
+            new_cls.model_config['table'] = config_table
+            for k, v in new_cls.model_fields.items():
                 col = get_column_from_field(v)
                 setattr(new_cls, k, col)
             # Set a config flag to tell FastAPI that this should be read with a field
             # in orm_mode instead of preemptively converting it to a dict.
-            # This could be done by reading new_cls.__config__.table in FastAPI, but
+            # This could be done by reading new_cls.model_config['table'] in FastAPI, but
             # that's very specific about SQLModel, so let's have another config that
             # other future tools based on Pydantic can use.
-            new_cls.__config__.read_with_orm_mode = True
+            new_cls.model_config['read_from_attributes'] = True

         config_registry = get_config("registry")
-        if config_registry is not Undefined:
+        if config_registry is not PydanticUndefined:
             config_registry = cast(registry, config_registry)
             # If it was passed by kwargs, ensure it's also set in config
-            new_cls.__config__.registry = config_table
+            new_cls.model_config['registry'] = config_table
             setattr(new_cls, "_sa_registry", config_registry)
             setattr(new_cls, "metadata", config_registry.metadata)
             setattr(new_cls, "__abstract__", True)
@@ -320,13 +326,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
         # triggers an error
         base_is_table = False
         for base in bases:
-            config = getattr(base, "__config__")
+            config = getattr(base, "model_config")
             if config and getattr(config, "table", False):
                 base_is_table = True
                 break
-        if getattr(cls.__config__, "table", False) and not base_is_table:
+        if cls.model_config.get("table", False) and not base_is_table:
             dict_used = dict_.copy()
-            for field_name, field_value in cls.__fields__.items():
+            for field_name, field_value in cls.model_fields.items():
                 dict_used[field_name] = get_column_from_field(field_value)
             for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
                 if rel_info.sa_relationship:
@@ -335,16 +341,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
                     dict_used[rel_name] = rel_info.sa_relationship
                     continue
                 ann = cls.__annotations__[rel_name]
-                temp_field = ModelField.infer(
-                    name=rel_name,
-                    value=rel_info,
-                    annotation=ann,
-                    class_validators=None,
-                    config=BaseConfig,
-                )
-                relationship_to = temp_field.type_
-                if isinstance(temp_field.type_, ForwardRef):
-                    relationship_to = temp_field.type_.__forward_arg__
+                relationship_to = get_origin(ann)
+                # If Union (Optional), get the real field
+                if relationship_to is Union:
+                    relationship_to = get_args(ann)[0]
+                # If a list, then also get the real field
+                elif relationship_to is list:
+                    relationship_to = get_args(ann)[0]
+                if isinstance(relationship_to, ForwardRef):
+                    relationship_to = relationship_to.__forward_arg__
                 rel_kwargs: Dict[str, Any] = {}
                 if rel_info.back_populates:
                     rel_kwargs["back_populates"] = rel_info.back_populates
@@ -362,7 +367,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
                     rel_args.extend(rel_info.sa_relationship_args)
                 if rel_info.sa_relationship_kwargs:
                     rel_kwargs.update(rel_info.sa_relationship_kwargs)
-                rel_value: RelationshipProperty = relationship( # type: ignore
+                rel_value: RelationshipProperty = relationship(
                     relationship_to, *rel_args, **rel_kwargs
                 )
                 dict_used[rel_name] = rel_value
@@ -372,68 +377,78 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
         ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


-def get_sqlalchemy_type(field: ModelField) -> Any:
-    if issubclass(field.type_, str):
-        if field.field_info.max_length:
-            return AutoString(length=field.field_info.max_length)
+def get_sqlalchemy_type(field: FieldInfo) -> Any:
+    type_ = field.annotation
+    # Resolve Optional fields
+    if type_ is not None and get_origin(type_) is Union:
+        bases = get_args(type_)
+        if len(bases) > 2:
+            raise RuntimeError("Cannot have a (non-optional) union as a SQLAlchemy field")
+        type_ = bases[0]
+
+    # The 3rd is PydanticGeneralMetadata
+    metadata = _get_field_metadata(field)
+    if issubclass(type_, str):
+        if getattr(metadata, "max_length", None):
+            return AutoString(length=metadata.max_length)
         return AutoString
-    if issubclass(field.type_, float):
+    if issubclass(type_, float):
         return Float
-    if issubclass(field.type_, bool):
+    if issubclass(type_, bool):
         return Boolean
-    if issubclass(field.type_, int):
+    if issubclass(type_, int):
         return Integer
-    if issubclass(field.type_, datetime):
+    if issubclass(type_, datetime):
         return DateTime
-    if issubclass(field.type_, date):
+    if issubclass(type_, date):
         return Date
-    if issubclass(field.type_, timedelta):
+    if issubclass(type_, timedelta):
         return Interval
-    if issubclass(field.type_, time):
+    if issubclass(type_, time):
         return Time
-    if issubclass(field.type_, Enum):
-        return sa_Enum(field.type_)
-    if issubclass(field.type_, bytes):
+    if issubclass(type_, Enum):
+        return sa_Enum(type_)
+    if issubclass(type_, bytes):
         return LargeBinary
-    if issubclass(field.type_, Decimal):
+    if issubclass(type_, Decimal):
         return Numeric(
-            precision=getattr(field.type_, "max_digits", None),
-            scale=getattr(field.type_, "decimal_places", None),
+            precision=getattr(metadata, "max_digits", None),
+            scale=getattr(metadata, "decimal_places", None),
         )
-    if issubclass(field.type_, ipaddress.IPv4Address):
+    if issubclass(type_, ipaddress.IPv4Address):
         return AutoString
-    if issubclass(field.type_, ipaddress.IPv4Network):
+    if issubclass(type_, ipaddress.IPv4Network):
         return AutoString
-    if issubclass(field.type_, ipaddress.IPv6Address):
+    if issubclass(type_, ipaddress.IPv6Address):
         return AutoString
-    if issubclass(field.type_, ipaddress.IPv6Network):
+    if issubclass(type_, ipaddress.IPv6Network):
         return AutoString
-    if issubclass(field.type_, Path):
+    if issubclass(type_, Path):
         return AutoString
-    if issubclass(field.type_, uuid.UUID):
+    if issubclass(type_, uuid.UUID):
         return GUID
-    raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
+    raise ValueError(f"The field {field.title} has no matching SQLAlchemy type")


-def get_column_from_field(field: ModelField) -> Column: # type: ignore
-    sa_column = getattr(field.field_info, "sa_column", Undefined)
+def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
+    sa_column = getattr(field, "sa_column", PydanticUndefined)
     if isinstance(sa_column, Column):
         return sa_column
     sa_type = get_sqlalchemy_type(field)
-    primary_key = getattr(field.field_info, "primary_key", False)
-    index = getattr(field.field_info, "index", Undefined)
-    if index is Undefined:
+    primary_key = getattr(field, "primary_key", False)
+    index = getattr(field, "index", PydanticUndefined)
+    if index is PydanticUndefined:
         index = False
     nullable = not primary_key and _is_field_noneable(field)
     # Override derived nullability if the nullable property is set explicitly
     # on the field
-    if hasattr(field.field_info, "nullable"):
-        field_nullable = getattr(field.field_info, "nullable")
-        if field_nullable != Undefined:
+    if hasattr(field, "nullable"):
+        field_nullable = getattr(field, "nullable")
+        if field_nullable != PydanticUndefined:
             nullable = field_nullable
     args = []
-    foreign_key = getattr(field.field_info, "foreign_key", None)
-    unique = getattr(field.field_info, "unique", False)
+    foreign_key = getattr(field, "foreign_key", None)
+    unique = getattr(field, "unique", False)
     if foreign_key:
         args.append(ForeignKey(foreign_key))
     kwargs = {
@@ -442,18 +457,18 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
         "index": index,
         "unique": unique,
     }
-    sa_default = Undefined
-    if field.field_info.default_factory:
-        sa_default = field.field_info.default_factory
-    elif field.field_info.default is not Undefined:
-        sa_default = field.field_info.default
-    if sa_default is not Undefined:
+    sa_default = PydanticUndefined
+    if field.default_factory:
+        sa_default = field.default_factory
+    elif field.default is not PydanticUndefined:
+        sa_default = field.default
+    if sa_default is not PydanticUndefined:
         kwargs["default"] = sa_default
-    sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
-    if sa_column_args is not Undefined:
+    sa_column_args = getattr(field, "sa_column_args", PydanticUndefined)
+    if sa_column_args is not PydanticUndefined:
         args.extend(list(cast(Sequence[Any], sa_column_args)))
-    sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
-    if sa_column_kwargs is not Undefined:
+    sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined)
+    if sa_column_kwargs is not PydanticUndefined:
         kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
     return Column(sa_type, *args, **kwargs) # type: ignore
@@ -462,13 +477,6 @@ class_registry = weakref.WeakValueDictionary() # type: ignore

 default_registry = registry()


-def _value_items_is_true(v: Any) -> bool:
-    # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
-    # the current latest, Pydantic 1.8.2
-    return v is True or v is ...
-
-
 _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")

@@ -476,43 +484,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
     # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
     __slots__ = ("__weakref__",)
     __tablename__: ClassVar[Union[str, Callable[..., str]]]
-    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
+    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
     __name__: ClassVar[str]
     metadata: ClassVar[MetaData]
+    __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six

-    class Config:
-        orm_mode = True
-
-    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
-        new_object = super().__new__(cls)
-        # SQLAlchemy doesn't call __init__ on the base class
-        # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
-        # Set __fields_set__ here, that would have been set when calling __init__
-        # in the Pydantic model so that when SQLAlchemy sets attributes that are
-        # added (e.g. when querying from DB) to the __fields_set__, this already exists
-        object.__setattr__(new_object, "__fields_set__", set())
-        return new_object
+    model_config = SQLModelConfig(from_attributes=True)

     def __init__(__pydantic_self__, **data: Any) -> None:
         # Uses something other than `self` the first arg to allow "self" as a
         # settable attribute
-        values, fields_set, validation_error = validate_model(
-            __pydantic_self__.__class__, data
-        )
-        # Only raise errors if not a SQLModel model
-        if (
-            not getattr(__pydantic_self__.__config__, "table", False)
-            and validation_error
-        ):
-            raise validation_error
-        # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
-        # can handle them
-        # object.__setattr__(__pydantic_self__, '__dict__', values)
-        for key, value in values.items():
-            setattr(__pydantic_self__, key, value)
-        object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
-        non_pydantic_keys = data.keys() - values.keys()
+        old_dict = __pydantic_self__.__dict__.copy()
+        super().__init__(**data)
+        __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
+        non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
         for key in non_pydantic_keys:
             if key in __pydantic_self__.__sqlmodel_relationships__:
                 setattr(__pydantic_self__, key, data[key])
@@ -523,135 +505,36 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
             return
         else:
             # Set in SQLAlchemy, before Pydantic to trigger events and updates
-            if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore
+            if self.model_config.get("table", False) and is_instrumented(self, name):
                 set_attribute(self, name, value)
             # Set in Pydantic model to trigger possible validation changes, only for
             # non relationship values
             if name not in self.__sqlmodel_relationships__:
-                super().__setattr__(name, value)
-
-    @classmethod
-    def from_orm(
-        cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
-    ) -> _TSQLModel:
-        # Duplicated from Pydantic
-        if not cls.__config__.orm_mode:
-            raise ConfigError(
-                "You must have the config attribute orm_mode=True to use from_orm"
-            )
-        obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
-        # SQLModel, support update dict
-        if update is not None:
-            obj = {**obj, **update}
-        # End SQLModel support dict
-        if not getattr(cls.__config__, "table", False):
-            # If not table, normal Pydantic code
-            m: _TSQLModel = cls.__new__(cls)
-        else:
-            # If table, create the new instance normally to make SQLAlchemy create
-            # the _sa_instance_state attribute
-            m = cls()
-        values, fields_set, validation_error = validate_model(cls, obj)
-        if validation_error:
-            raise validation_error
-        # Updated to trigger SQLAlchemy internal handling
-        if not getattr(cls.__config__, "table", False):
-            object.__setattr__(m, "__dict__", values)
-        else:
-            for key, value in values.items():
-                setattr(m, key, value)
-        # Continue with standard Pydantic logic
-        object.__setattr__(m, "__fields_set__", fields_set)
-        m._init_private_attributes()
-        return m
-
-    @classmethod
-    def parse_obj(
-        cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
-    ) -> _TSQLModel:
-        obj = cls._enforce_dict_if_root(obj)
-        # SQLModel, support update dict
-        if update is not None:
-            obj = {**obj, **update}
-        # End SQLModel support dict
-        return super().parse_obj(obj)
+                super(SQLModel, self).__setattr__(name, value)

     def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
         # Don't show SQLAlchemy private attributes
         return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]

-    # From Pydantic, override to enforce validation with dict
-    @classmethod
-    def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel:
-        if isinstance(value, cls):
-            return value.copy() if cls.__config__.copy_on_model_validation else value
-
-        value = cls._enforce_dict_if_root(value)
-        if isinstance(value, dict):
-            values, fields_set, validation_error = validate_model(cls, value)
-            if validation_error:
-                raise validation_error
-            model = cls(**value)
-            # Reset fields set, this would have been done in Pydantic in __init__
-            object.__setattr__(model, "__fields_set__", fields_set)
-            return model
-        elif cls.__config__.orm_mode:
-            return cls.from_orm(value)
-        elif cls.__custom_root_type__:
-            return cls.parse_obj(value)
-        else:
-            try:
-                value_as_dict = dict(value)
-            except (TypeError, ValueError) as e:
-                raise DictError() from e
-            return cls(**value_as_dict)
-
-    # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
-    def _calculate_keys(
-        self,
-        include: Optional[Mapping[Union[int, str], Any]],
-        exclude: Optional[Mapping[Union[int, str], Any]],
-        exclude_unset: bool,
-        update: Optional[Dict[str, Any]] = None,
-    ) -> Optional[AbstractSet[str]]:
-        if include is None and exclude is None and not exclude_unset:
-            # Original in Pydantic:
-            # return None
-            # Updated to not return SQLAlchemy attributes
-            # Do not include relationships as that would easily lead to infinite
-            # recursion, or traversing the whole database
-            return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
-
-        keys: AbstractSet[str]
-        if exclude_unset:
-            keys = self.__fields_set__.copy()
-        else:
-            # Original in Pydantic:
-            # keys = self.__dict__.keys()
-            # Updated to not return SQLAlchemy attributes
-            # Do not include relationships as that would easily lead to infinite
-            # recursion, or traversing the whole database
-            keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
-        if include is not None:
-            keys &= include.keys()
-
-        if update:
-            keys -= update.keys()
-
-        if exclude:
-            keys -= {k for k, v in exclude.items() if _value_items_is_true(v)}
-
-        return keys
-
     @declared_attr # type: ignore
     def __tablename__(cls) -> str:
         return cls.__name__.lower()


-def _is_field_noneable(field: ModelField) -> bool:
-    if not field.required:
-        # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
-        return field.allow_none and (
-            field.shape != SHAPE_SINGLETON or not field.sub_fields
-        )
+def _is_field_noneable(field: FieldInfo) -> bool:
+    if not field.is_required():
+        if field.annotation is None or field.annotation is type(None):
+            return True
+        if get_origin(field.annotation) is Union:
+            for base in get_args(field.annotation):
+                if base is type(None):
+                    return True
+        return False
+    return False
+
+
+def _get_field_metadata(field: FieldInfo) -> object:
+    for meta in field.metadata:
+        if isinstance(meta, PydanticGeneralMetadata):
+            return meta
+    return object()