mirror of
https://github.com/PaiGramTeam/sqlmodel.git
synced 2024-11-22 07:08:06 +00:00
Make sure tests pass in all supported python versions
This commit is contained in:
parent
179183c018
commit
c99c1a9b1e
@ -1,4 +1,7 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
import weakref
|
import weakref
|
||||||
from datetime import date, datetime, time, timedelta
|
from datetime import date, datetime, time, timedelta
|
||||||
@ -22,8 +25,6 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -45,6 +46,11 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time
|
|||||||
from .sql.sqltypes import GUID, AutoString
|
from .sql.sqltypes import GUID, AutoString
|
||||||
from .typing import SQLModelConfig
|
from .typing import SQLModelConfig
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 8):
|
||||||
|
from typing import get_args, get_origin
|
||||||
|
else:
|
||||||
|
from typing_extensions import get_args, get_origin
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
NoArgAnyCallable = Callable[[], Any]
|
NoArgAnyCallable = Callable[[], Any]
|
||||||
NoneType = type(None)
|
NoneType = type(None)
|
||||||
@ -61,7 +67,6 @@ def __dataclass_transform__(
|
|||||||
|
|
||||||
|
|
||||||
class FieldInfo(PydanticFieldInfo):
|
class FieldInfo(PydanticFieldInfo):
|
||||||
|
|
||||||
nullable: Union[bool, PydanticUndefinedType]
|
nullable: Union[bool, PydanticUndefinedType]
|
||||||
|
|
||||||
def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
|
def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
|
||||||
@ -401,7 +406,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||||
type_: type | None = field.annotation
|
type_: Optional[type] = field.annotation
|
||||||
|
|
||||||
# Resolve Optional fields
|
# Resolve Optional fields
|
||||||
if type_ is not None and get_origin(type_) is Union:
|
if type_ is not None and get_origin(type_) is Union:
|
||||||
@ -486,7 +491,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
|||||||
"index": index,
|
"index": index,
|
||||||
"unique": unique,
|
"unique": unique,
|
||||||
}
|
}
|
||||||
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined
|
sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined
|
||||||
if field.default_factory:
|
if field.default_factory:
|
||||||
sa_default = field.default_factory
|
sa_default = field.default_factory
|
||||||
elif field.default is not PydanticUndefined:
|
elif field.default is not PydanticUndefined:
|
||||||
@ -531,7 +536,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
|||||||
def __init__(__pydantic_self__, **data: Any) -> None:
|
def __init__(__pydantic_self__, **data: Any) -> None:
|
||||||
old_dict = __pydantic_self__.__dict__.copy()
|
old_dict = __pydantic_self__.__dict__.copy()
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
__pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
|
__pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
|
||||||
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
|
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
|
||||||
for key in non_pydantic_keys:
|
for key in non_pydantic_keys:
|
||||||
if key in __pydantic_self__.__sqlmodel_relationships__:
|
if key in __pydantic_self__.__sqlmodel_relationships__:
|
||||||
@ -560,12 +565,12 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def model_validate(
|
def model_validate(
|
||||||
cls: type[_TSQLModel],
|
cls: Type[_TSQLModel],
|
||||||
obj: Any,
|
obj: Any,
|
||||||
*,
|
*,
|
||||||
strict: bool | None = None,
|
strict: Optional[bool] = None,
|
||||||
from_attributes: bool | None = None,
|
from_attributes: Optional[bool] = None,
|
||||||
context: dict[str, Any] | None = None,
|
context: Optional[Dict[str, Any]] = None,
|
||||||
) -> _TSQLModel:
|
) -> _TSQLModel:
|
||||||
# Somehow model validate doesn't call __init__ so it would remove our init logic
|
# Somehow model validate doesn't call __init__ so it would remove our init logic
|
||||||
validated = super().model_validate(
|
validated = super().model_validate(
|
||||||
@ -590,7 +595,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
|||||||
|
|
||||||
|
|
||||||
def _is_field_noneable(field: FieldInfo) -> bool:
|
def _is_field_noneable(field: FieldInfo) -> bool:
|
||||||
if not isinstance(field.nullable, PydanticUndefinedType):
|
if hasattr(field, "nullable") and not isinstance(
|
||||||
|
field.nullable, PydanticUndefinedType
|
||||||
|
):
|
||||||
return field.nullable
|
return field.nullable
|
||||||
if not field.is_required():
|
if not field.is_required():
|
||||||
default = getattr(field, "original_default", field.default)
|
default = getattr(field, "original_default", field.default)
|
||||||
|
Loading…
Reference in New Issue
Block a user