From c99c1a9b1ea2f6492ad9ceee8a4ce6dc7b8b8f18 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Wed, 9 Aug 2023 12:29:23 +0100 Subject: [PATCH] Make sure tests pass in all supported python versions --- sqlmodel/main.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 56ba414..6e56dcd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import ipaddress +import sys import uuid import weakref from datetime import date, datetime, time, timedelta @@ -22,8 +25,6 @@ from typing import ( TypeVar, Union, cast, - get_args, - get_origin, ) from pydantic import BaseModel @@ -45,6 +46,11 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time from .sql.sqltypes import GUID, AutoString from .typing import SQLModelConfig +if sys.version_info >= (3, 8): + from typing import get_args, get_origin +else: + from typing_extensions import get_args, get_origin + _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] NoneType = type(None) @@ -61,7 +67,6 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): - nullable: Union[bool, PydanticUndefinedType] def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: @@ -401,7 +406,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: FieldInfo) -> Any: - type_: type | None = field.annotation + type_: Optional[type] = field.annotation # Resolve Optional fields if type_ is not None and get_origin(type_) is Union: @@ -486,7 +491,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined + sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined if field.default_factory: sa_default = field.default_factory elif field.default is not PydanticUndefined: @@ -531,7 +536,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry def __init__(__pydantic_self__, **data: Any) -> None: old_dict = __pydantic_self__.__dict__.copy() super().__init__(**data) - __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__ + __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} non_pydantic_keys = data.keys() - __pydantic_self__.model_fields for key in non_pydantic_keys: if key in __pydantic_self__.__sqlmodel_relationships__: @@ -560,12 +565,12 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry @classmethod def model_validate( - cls: type[_TSQLModel], + cls: Type[_TSQLModel], obj: Any, *, - strict: bool | None = None, - from_attributes: bool | None = None, - context: dict[str, Any] | None = None, + strict: Optional[bool] = None, + from_attributes: Optional[bool] = None, + context: Optional[Dict[str, Any]] = None, ) -> _TSQLModel: # Somehow model validate doesn't call __init__ so it would remove our init logic validated = super().model_validate( @@ -590,7 +595,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry def _is_field_noneable(field: FieldInfo) -> bool: - if not isinstance(field.nullable, PydanticUndefinedType): + if hasattr(field, "nullable") and not isinstance( + field.nullable, PydanticUndefinedType + ): return field.nullable if not field.is_required(): default = getattr(field, "original_default", field.default)