From 82888160acfd596f5f75d89df309b4f952d893cf Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 1 Aug 2023 08:13:30 +0000 Subject: [PATCH] Make all tests but fastapi work --- sqlmodel/main.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 79a233c..86d9b19 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -14,7 +14,6 @@ from typing import ( ForwardRef, List, Mapping, - NoneType, Optional, Sequence, Set, @@ -48,6 +47,7 @@ from .typing import SQLModelConfig _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] +NoneType = type(None) def __dataclass_transform__( @@ -273,13 +273,17 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } - config_table = getattr(class_dict.get("Config", object()), "table", False) + config_table = getattr(class_dict.get("Config", object()), "table", False) or kwargs.get("table", False) # If we have a table, we need to have defaults for all fields # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything if config_table is True: - for key in original_annotations.keys(): - if dict_used.get(key, PydanticUndefined) is PydanticUndefined: + for key in pydantic_annotations.keys(): + value = dict_used.get(key, PydanticUndefined) + if value is PydanticUndefined: dict_used[key] = None + elif isinstance(value, FieldInfo): + if value.default is PydanticUndefined and value.default_factory is None: + value.default = None new_cls: Type["SQLModelMetaclass"] = super().__new__( cls, name, bases, dict_used, **config_kwargs @@ -349,8 +353,11 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): continue ann = cls.__annotations__[rel_name] relationship_to = get_origin(ann) - # If Union (Optional), get the real field - if relationship_to is Union: + # Direct relationships (e.g. 'Team' or Team) have None as an origin + if relationship_to is None: + relationship_to = ann + # If Union (e.g. Optional), get the real field + elif relationship_to is Union: relationship_to = get_args(ann)[0] # If a list, then also get the real field elif relationship_to is list: @@ -501,6 +508,16 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six model_config = SQLModelConfig(from_attributes=True) + def __new__(cls, *args: Any, **kwargs: Any) -> Any: + new_object = super().__new__(cls) + # SQLAlchemy doesn't call __init__ on the base class + # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html + # Set __fields_set__ here, that would have been set when calling __init__ + # in the Pydantic model so that when SQLAlchemy sets attributes that are + # added (e.g. when querying from DB) to the __fields_set__, this already exists + object.__setattr__(new_object, "__pydantic_fields_set__", set()) + return new_object + def __init__(__pydantic_self__, **data: Any) -> None: old_dict = __pydantic_self__.__dict__.copy() super().__init__(**data) @@ -531,6 +548,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry def __tablename__(cls) -> str: return cls.__name__.lower() + @classmethod + def model_validate(cls, *args, **kwargs): + return super().model_validate(*args, **kwargs) + def _is_field_noneable(field: FieldInfo) -> bool: if not field.is_required():