Make sure tests pass in all supported python versions

This commit is contained in:
Santiago Martinez 2023-08-09 12:29:23 +01:00
parent 179183c018
commit c99c1a9b1e

View File

@ -1,4 +1,7 @@
from __future__ import annotations
import ipaddress import ipaddress
import sys
import uuid import uuid
import weakref import weakref
from datetime import date, datetime, time, timedelta from datetime import date, datetime, time, timedelta
@ -22,8 +25,6 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
cast, cast,
get_args,
get_origin,
) )
from pydantic import BaseModel from pydantic import BaseModel
@ -45,6 +46,11 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time
from .sql.sqltypes import GUID, AutoString from .sql.sqltypes import GUID, AutoString
from .typing import SQLModelConfig from .typing import SQLModelConfig
if sys.version_info >= (3, 8):
from typing import get_args, get_origin
else:
from typing_extensions import get_args, get_origin
_T = TypeVar("_T") _T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any] NoArgAnyCallable = Callable[[], Any]
NoneType = type(None) NoneType = type(None)
@ -61,7 +67,6 @@ def __dataclass_transform__(
class FieldInfo(PydanticFieldInfo): class FieldInfo(PydanticFieldInfo):
nullable: Union[bool, PydanticUndefinedType] nullable: Union[bool, PydanticUndefinedType]
def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
@ -401,7 +406,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
def get_sqlalchemy_type(field: FieldInfo) -> Any: def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: type | None = field.annotation type_: Optional[type] = field.annotation
# Resolve Optional fields # Resolve Optional fields
if type_ is not None and get_origin(type_) is Union: if type_ is not None and get_origin(type_) is Union:
@ -486,7 +491,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"index": index, "index": index,
"unique": unique, "unique": unique,
} }
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined
if field.default_factory: if field.default_factory:
sa_default = field.default_factory sa_default = field.default_factory
elif field.default is not PydanticUndefined: elif field.default is not PydanticUndefined:
@ -531,7 +536,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def __init__(__pydantic_self__, **data: Any) -> None: def __init__(__pydantic_self__, **data: Any) -> None:
old_dict = __pydantic_self__.__dict__.copy() old_dict = __pydantic_self__.__dict__.copy()
super().__init__(**data) super().__init__(**data)
__pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__ __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
for key in non_pydantic_keys: for key in non_pydantic_keys:
if key in __pydantic_self__.__sqlmodel_relationships__: if key in __pydantic_self__.__sqlmodel_relationships__:
@ -560,12 +565,12 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
@classmethod @classmethod
def model_validate( def model_validate(
cls: type[_TSQLModel], cls: Type[_TSQLModel],
obj: Any, obj: Any,
*, *,
strict: bool | None = None, strict: Optional[bool] = None,
from_attributes: bool | None = None, from_attributes: Optional[bool] = None,
context: dict[str, Any] | None = None, context: Optional[Dict[str, Any]] = None,
) -> _TSQLModel: ) -> _TSQLModel:
# Somehow model validate doesn't call __init__ so it would remove our init logic # Somehow model validate doesn't call __init__ so it would remove our init logic
validated = super().model_validate( validated = super().model_validate(
@ -590,7 +595,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def _is_field_noneable(field: FieldInfo) -> bool: def _is_field_noneable(field: FieldInfo) -> bool:
if not isinstance(field.nullable, PydanticUndefinedType): if hasattr(field, "nullable") and not isinstance(
field.nullable, PydanticUndefinedType
):
return field.nullable return field.nullable
if not field.is_required(): if not field.is_required():
default = getattr(field, "original_default", field.default) default = getattr(field, "original_default", field.default)