Make sure tests pass in all supported python versions

This commit is contained in:
Santiago Martinez 2023-08-09 12:29:23 +01:00
parent 179183c018
commit c99c1a9b1e

View File

@ -1,4 +1,7 @@
from __future__ import annotations
import ipaddress
import sys
import uuid
import weakref
from datetime import date, datetime, time, timedelta
@ -22,8 +25,6 @@ from typing import (
TypeVar,
Union,
cast,
get_args,
get_origin,
)
from pydantic import BaseModel
@ -45,6 +46,11 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time
from .sql.sqltypes import GUID, AutoString
from .typing import SQLModelConfig
if sys.version_info >= (3, 8):
from typing import get_args, get_origin
else:
from typing_extensions import get_args, get_origin
_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
NoneType = type(None)
@ -61,7 +67,6 @@ def __dataclass_transform__(
class FieldInfo(PydanticFieldInfo):
nullable: Union[bool, PydanticUndefinedType]
def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
@ -401,7 +406,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: type | None = field.annotation
type_: Optional[type] = field.annotation
# Resolve Optional fields
if type_ is not None and get_origin(type_) is Union:
@ -486,7 +491,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"index": index,
"unique": unique,
}
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined
sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined
if field.default_factory:
sa_default = field.default_factory
elif field.default is not PydanticUndefined:
@ -531,7 +536,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def __init__(__pydantic_self__, **data: Any) -> None:
old_dict = __pydantic_self__.__dict__.copy()
super().__init__(**data)
__pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
__pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
for key in non_pydantic_keys:
if key in __pydantic_self__.__sqlmodel_relationships__:
@ -560,12 +565,12 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
@classmethod
def model_validate(
cls: type[_TSQLModel],
cls: Type[_TSQLModel],
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
strict: Optional[bool] = None,
from_attributes: Optional[bool] = None,
context: Optional[Dict[str, Any]] = None,
) -> _TSQLModel:
# Somehow model validate doesn't call __init__ so it would remove our init logic
validated = super().model_validate(
@ -590,7 +595,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def _is_field_noneable(field: FieldInfo) -> bool:
if not isinstance(field.nullable, PydanticUndefinedType):
if hasattr(field, "nullable") and not isinstance(
field.nullable, PydanticUndefinedType
):
return field.nullable
if not field.is_required():
default = getattr(field, "original_default", field.default)