Merge pull request #2 from honglei/main

get_column_from_field support functional sa_column
This commit is contained in:
Santiago Martinez Balvanera 2023-08-25 17:05:25 +01:00 committed by GitHub
commit bcb6f32128
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 187 additions and 18 deletions

View File

@ -55,10 +55,7 @@ if sys.version_info >= (3, 8):
else: else:
from typing_extensions import get_args, get_origin from typing_extensions import get_args, get_origin
if sys.version_info >= (3, 9): from typing_extensions import Annotated, _AnnotatedAlias
from typing import Annotated
else:
from typing_extensions import Annotated
_T = TypeVar("_T") _T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any] NoArgAnyCallable = Callable[[], Any]
@ -167,7 +164,7 @@ def Field(
unique: bool = False, unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[ sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType Mapping[str, Any], PydanticUndefinedType
@ -440,17 +437,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def get_sqlalchemy_type(field: FieldInfo) -> Any: def _is_optional_or_union(type_: Optional[type]) -> bool:
type_: Optional[type] = field.annotation
# Resolve Optional/Union fields
def is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union) return get_origin(type_) in (types.UnionType, Union)
else: else:
return get_origin(type_) is Union return get_origin(type_) is Union
if type_ is not None and is_optional_or_union(type_):
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] | _AnnotatedAlias = field.annotation
# Resolve Optional/Union fields
if type_ is not None and _is_optional_or_union(type_):
bases = get_args(type_) bases = get_args(type_)
if len(bases) > 2: if len(bases) > 2:
raise RuntimeError( raise RuntimeError(
@ -462,14 +461,20 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
# UrlConstraints(max_length=512, # UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ] # allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl: if type_ is pydantic.AnyUrl:
if field.metadata:
meta = field.metadata[0] meta = field.metadata[0]
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
else:
return AutoString
if get_origin(type_) is Annotated: org_type = get_origin(type_)
if org_type is Annotated:
type2 = get_args(type_)[0] type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl: if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1] meta = get_args(type_)[1]
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
return AutoString(type_.__metadata__[0].max_length)
# The 3rd is PydanticGeneralMetadata # The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field) metadata = _get_field_metadata(field)
@ -519,11 +524,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"""
sa_column > field attributes > annotation info
"""
sa_column = getattr(field, "sa_column", PydanticUndefined) sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column): if isinstance(sa_column, Column):
return sa_column return sa_column
if isinstance(sa_column, MappedColumn): if isinstance(sa_column, MappedColumn):
return sa_column.column return sa_column.column
if isinstance(sa_column, types.FunctionType):
col = sa_column()
assert isinstance(col, Column)
return col
sa_type = get_sqlalchemy_type(field) sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False) primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined) index = getattr(field, "index", PydanticUndefined)
@ -587,6 +599,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# in the Pydantic model so that when SQLAlchemy sets attributes that are # in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists # added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__pydantic_fields_set__", set()) object.__setattr__(new_object, "__pydantic_fields_set__", set())
if not hasattr(new_object, "__pydantic_extra__"):
object.__setattr__(new_object, "__pydantic_extra__", None)
if not hasattr(new_object, "__pydantic_private__"):
object.__setattr__(new_object, "__pydantic_private__", None)
return new_object return new_object
def __init__(__pydantic_self__, **data: Any) -> None: def __init__(__pydantic_self__, **data: Any) -> None:
@ -636,7 +652,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# remove defaults so they don't get validated # remove defaults so they don't get validated
data = {} data = {}
for key, value in validated: for key, value in validated:
field = cls.model_fields[key] field = cls.model_fields.get(key)
if field is None:
continue
if ( if (
hasattr(field, "default") hasattr(field, "default")
@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
return False return False
if field.annotation is None or field.annotation is NoneType: if field.annotation is None or field.annotation is NoneType:
return True return True
if get_origin(field.annotation) is Union: if _is_optional_or_union(field.annotation):
for base in get_args(field.annotation): for base in get_args(field.annotation):
if base is NoneType: if base is NoneType:
return True return True
return False return False
return False return False

View File

@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine
class AutoString(types.TypeDecorator): # type: ignore class AutoString(types.TypeDecorator): # type: ignore
impl = types.String impl = types.String
cache_ok = True cache_ok = True
mysql_default_length = 255 mysql_default_length = 255

View File

@ -0,0 +1,78 @@
import datetime
import sys
import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
Integer,
SQLModel,
String,
create_engine,
)
from typing_extensions import Annotated
MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_field_resuse():
class BasicFileLog(SQLModel):
resourceID: int = Field(
sa_column=lambda: Column(Integer, index=True), description=""" """
)
transportID: Annotated[int | None, Field(description=" for ")] = None
fileName: str = Field(
sa_column=lambda: Column(String, index=True), description=""" """
)
fileSize: int | None = Field(
sa_column=lambda: Column(BigInteger), ge=0, description=""" """
)
beginTime: datetime.datetime | None = Field(
sa_column=lambda: Column(
DateTime(timezone=True),
index=True,
),
description="",
)
class SendFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
sendUser: str
dstUrl: MoveSharedUrl | None
class RecvFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
recvUser: str
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
engine = create_engine(sqlite_url, echo=True)
SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
SendFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)
RecvFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)

50
tests/test_model_copy.py Normal file
View File

@ -0,0 +1,50 @@
from typing import Optional
from sqlmodel import Field, Session, SQLModel, create_engine
def test_model_copy(clear_sqlmodel):
"""Test validation of implicit and explict None values.
# For consistency with pydantic, validators are not to be called on
# arguments that are not explicitly provided.
https://github.com/tiangolo/sqlmodel/issues/230
https://github.com/samuelcolvin/pydantic/issues/1223
"""
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
session.add(hero)
session.commit()
session.refresh(hero)
model_copy = hero.model_copy(update={"name": "Deadpond Copy"})
assert (
model_copy.name == "Deadpond Copy"
and model_copy.secret_name == "Dive Wilson"
and model_copy.age == 25
)
db_hero = session.get(Hero, hero.id)
db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"})
assert (
db_copy.name == "Deadpond Copy"
and db_copy.secret_name == "Dive Wilson"
and db_copy.age == 25
)

View File

@ -1,8 +1,14 @@
from typing import Optional from typing import Optional
import pytest import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Session, SQLModel, create_engine from sqlmodel import Field, Session, SQLModel, create_engine
from typing_extensions import Annotated
MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]
def test_nullable_fields(clear_sqlmodel, caplog): def test_nullable_fields(clear_sqlmodel, caplog):
@ -13,6 +19,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
) )
required_value: str required_value: str
optional_default_ellipsis: Optional[str] = Field(default=...) optional_default_ellipsis: Optional[str] = Field(default=...)
optional_no_field: Optional[str]
optional_no_field_default: Optional[str] = Field(description="no default")
optional_default_none: Optional[str] = Field(default=None) optional_default_none: Optional[str] = Field(default=None)
optional_non_nullable: Optional[str] = Field( optional_non_nullable: Optional[str] = Field(
nullable=False, nullable=False,
@ -49,6 +57,13 @@ def test_nullable_fields(clear_sqlmodel, caplog):
str_default_str_nullable: str = Field(default="default", nullable=True) str_default_str_nullable: str = Field(default="default", nullable=True)
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False) str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
str_default_ellipsis_nullable: str = Field(default=..., nullable=True) str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
base_url: AnyUrl
optional_url: Optional[MoveSharedUrl] = Field(default=None, description="")
url: MoveSharedUrl
annotated_url: Annotated[MoveSharedUrl, Field(description="")]
annotated_optional_url: Annotated[
Optional[MoveSharedUrl], Field(description="")
] = None
engine = create_engine("sqlite://", echo=True) engine = create_engine("sqlite://", echo=True)
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)
@ -59,6 +74,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
assert "primary_key INTEGER NOT NULL," in create_table_log assert "primary_key INTEGER NOT NULL," in create_table_log
assert "required_value VARCHAR NOT NULL," in create_table_log assert "required_value VARCHAR NOT NULL," in create_table_log
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
assert "optional_no_field VARCHAR," in create_table_log
assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log
assert "optional_default_none VARCHAR," in create_table_log assert "optional_default_none VARCHAR," in create_table_log
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
assert "optional_nullable VARCHAR," in create_table_log assert "optional_nullable VARCHAR," in create_table_log
@ -77,6 +94,11 @@ def test_nullable_fields(clear_sqlmodel, caplog):
assert "str_default_str_nullable VARCHAR," in create_table_log assert "str_default_str_nullable VARCHAR," in create_table_log
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
assert "base_url VARCHAR NOT NULL," in create_table_log
assert "optional_url VARCHAR(512), " in create_table_log
assert "url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_optional_url VARCHAR(512)," in create_table_log
# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420 # Test for regression in https://github.com/tiangolo/sqlmodel/issues/420