Make all tests but fastapi work

This commit is contained in:
Anton De Meester 2023-08-01 08:13:30 +00:00
parent 43d5d41a29
commit 82888160ac

View File

@ -14,7 +14,6 @@ from typing import (
ForwardRef, ForwardRef,
List, List,
Mapping, Mapping,
NoneType,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -48,6 +47,7 @@ from .typing import SQLModelConfig
_T = TypeVar("_T") _T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any] NoArgAnyCallable = Callable[[], Any]
NoneType = type(None)
def __dataclass_transform__( def __dataclass_transform__(
@ -273,13 +273,17 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
key: pydantic_kwargs.pop(key) key: pydantic_kwargs.pop(key)
for key in pydantic_kwargs.keys() & allowed_config_kwargs for key in pydantic_kwargs.keys() & allowed_config_kwargs
} }
config_table = getattr(class_dict.get("Config", object()), "table", False) config_table = getattr(class_dict.get("Config", object()), "table", False) or kwargs.get("table", False)
# If we have a table, we need to have defaults for all fields # If we have a table, we need to have defaults for all fields
# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
if config_table is True: if config_table is True:
for key in original_annotations.keys(): for key in pydantic_annotations.keys():
if dict_used.get(key, PydanticUndefined) is PydanticUndefined: value = dict_used.get(key, PydanticUndefined)
if value is PydanticUndefined:
dict_used[key] = None dict_used[key] = None
elif isinstance(value, FieldInfo):
if value.default is PydanticUndefined and value.default_factory is None:
value.default = None
new_cls: Type["SQLModelMetaclass"] = super().__new__( new_cls: Type["SQLModelMetaclass"] = super().__new__(
cls, name, bases, dict_used, **config_kwargs cls, name, bases, dict_used, **config_kwargs
@ -349,8 +353,11 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
continue continue
ann = cls.__annotations__[rel_name] ann = cls.__annotations__[rel_name]
relationship_to = get_origin(ann) relationship_to = get_origin(ann)
# If Union (Optional), get the real field # Direct relationships (e.g. 'Team' or Team) have None as an origin
if relationship_to is Union: if relationship_to is None:
relationship_to = ann
# If Union (e.g. Optional), get the real field
elif relationship_to is Union:
relationship_to = get_args(ann)[0] relationship_to = get_args(ann)[0]
# If a list, then also get the real field # If a list, then also get the real field
elif relationship_to is list: elif relationship_to is list:
@ -501,6 +508,16 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
__allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
model_config = SQLModelConfig(from_attributes=True) model_config = SQLModelConfig(from_attributes=True)
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
new_object = super().__new__(cls)
# SQLAlchemy doesn't call __init__ on the base class
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
# Set __fields_set__ here, that would have been set when calling __init__
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__pydantic_fields_set__", set())
return new_object
def __init__(__pydantic_self__, **data: Any) -> None: def __init__(__pydantic_self__, **data: Any) -> None:
old_dict = __pydantic_self__.__dict__.copy() old_dict = __pydantic_self__.__dict__.copy()
super().__init__(**data) super().__init__(**data)
@ -531,6 +548,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def __tablename__(cls) -> str: def __tablename__(cls) -> str:
return cls.__name__.lower() return cls.__name__.lower()
@classmethod
def model_validate(cls, *args, **kwargs):
return super().model_validate(*args, **kwargs)
def _is_field_noneable(field: FieldInfo) -> bool: def _is_field_noneable(field: FieldInfo) -> bool:
if not field.is_required(): if not field.is_required():