This commit is contained in:
Anton De Meester 2023-07-31 13:57:57 +00:00
parent 6955600120
commit 43d5d41a29
2 changed files with 20 additions and 15 deletions

View File

@ -14,6 +14,7 @@ from typing import (
ForwardRef,
List,
Mapping,
NoneType,
Optional,
Sequence,
Set,
@ -194,7 +195,7 @@ def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship: Optional[RelationshipProperty[Any]] = None,
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
@ -211,7 +212,7 @@ def Relationship(
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
model_config: Type[SQLModelConfig]
model_config: SQLModelConfig
model_fields: Dict[str, FieldInfo]
# Replicate SQLAlchemy
@ -280,7 +281,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
if dict_used.get(key, PydanticUndefined) is PydanticUndefined:
dict_used[key] = None
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
new_cls: Type["SQLModelMetaclass"] = super().__new__(
cls, name, bases, dict_used, **config_kwargs
)
new_cls.__annotations__ = {
**relationship_annotations,
**pydantic_annotations,
@ -371,7 +374,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship(
rel_value: RelationshipProperty[Any] = relationship(
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
@ -382,7 +385,8 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_ = field.annotation
type_: type | None = field.annotation
# Resolve Optional fields
if type_ is not None and get_origin(type_) is Union:
bases = get_args(type_)
@ -394,9 +398,12 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
if type_ is None:
raise ValueError("Missing field type")
if issubclass(type_, str):
if getattr(metadata, "max_length", None):
return AutoString(length=metadata.max_length)
max_length = getattr(metadata, "max_length", None)
if max_length:
return AutoString(length=max_length)
return AutoString
if issubclass(type_, float):
return Float
@ -463,7 +470,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"index": index,
"unique": unique,
}
sa_default = PydanticUndefined
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined
if field.default_factory:
sa_default = field.default_factory
elif field.default is not PydanticUndefined:
@ -483,14 +490,12 @@ class_registry = weakref.WeakValueDictionary() # type: ignore
default_registry = registry()
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
__allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
@ -511,7 +516,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if self.model_config.get("table", False) and is_instrumented(self, name):
if self.model_config.get("table", False) and is_instrumented(self, name): # type: ignore
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
@ -529,11 +534,11 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def _is_field_noneable(field: FieldInfo) -> bool:
if not field.is_required():
if field.annotation is None or field.annotation is type(None):
if field.annotation is None or field.annotation is NoneType:
return True
if get_origin(field.annotation) is Union:
for base in get_args(field.annotation):
if base is type(None):
if base is NoneType:
return True
return False
return False

View File

@ -3,7 +3,7 @@ from typing import Any, Optional
from pydantic import ConfigDict
class SQLModelConfig(ConfigDict):
class SQLModelConfig(ConfigDict, total=False):
table: Optional[bool]
read_from_attributes: Optional[bool]
registry: Optional[Any]