Merge pull request #2 from honglei/main

get_column_from_field support functional sa_column
This commit is contained in:
Santiago Martinez Balvanera 2023-08-25 17:05:25 +01:00 committed by GitHub
commit bcb6f32128
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 187 additions and 18 deletions

View File

@ -55,10 +55,7 @@ if sys.version_info >= (3, 8):
else:
from typing_extensions import get_args, get_origin
if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
from typing_extensions import Annotated, _AnnotatedAlias
_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
@ -167,7 +164,7 @@ def Field(
unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
@ -440,17 +437,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] = field.annotation
# Resolve Optional/Union fields
def is_optional_or_union(type_: Optional[type]) -> bool:
def _is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union
if type_ is not None and is_optional_or_union(type_):
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] | _AnnotatedAlias = field.annotation
# Resolve Optional/Union fields
if type_ is not None and _is_optional_or_union(type_):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
@ -462,14 +461,20 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
# UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl:
if field.metadata:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
else:
return AutoString
if get_origin(type_) is Annotated:
org_type = get_origin(type_)
if org_type is Annotated:
type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1]
return AutoString(length=meta.max_length)
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
return AutoString(type_.__metadata__[0].max_length)
# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
@ -519,11 +524,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"""
sa_column > field attributes > annotation info
"""
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
if isinstance(sa_column, types.FunctionType):
col = sa_column()
assert isinstance(col, Column)
return col
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)
@ -587,6 +599,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__pydantic_fields_set__", set())
if not hasattr(new_object, "__pydantic_extra__"):
object.__setattr__(new_object, "__pydantic_extra__", None)
if not hasattr(new_object, "__pydantic_private__"):
object.__setattr__(new_object, "__pydantic_private__", None)
return new_object
def __init__(__pydantic_self__, **data: Any) -> None:
@ -636,7 +652,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# remove defaults so they don't get validated
data = {}
for key, value in validated:
field = cls.model_fields[key]
field = cls.model_fields.get(key)
if field is None:
continue
if (
hasattr(field, "default")
@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
return False
if field.annotation is None or field.annotation is NoneType:
return True
if get_origin(field.annotation) is Union:
if _is_optional_or_union(field.annotation):
for base in get_args(field.annotation):
if base is NoneType:
return True
return False
return False

View File

@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine
class AutoString(types.TypeDecorator): # type: ignore
impl = types.String
cache_ok = True
mysql_default_length = 255

View File

@ -0,0 +1,78 @@
import datetime
import sys
import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
Integer,
SQLModel,
String,
create_engine,
)
from typing_extensions import Annotated
MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_field_resuse():
class BasicFileLog(SQLModel):
resourceID: int = Field(
sa_column=lambda: Column(Integer, index=True), description=""" """
)
transportID: Annotated[int | None, Field(description=" for ")] = None
fileName: str = Field(
sa_column=lambda: Column(String, index=True), description=""" """
)
fileSize: int | None = Field(
sa_column=lambda: Column(BigInteger), ge=0, description=""" """
)
beginTime: datetime.datetime | None = Field(
sa_column=lambda: Column(
DateTime(timezone=True),
index=True,
),
description="",
)
class SendFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
sendUser: str
dstUrl: MoveSharedUrl | None
class RecvFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
recvUser: str
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
engine = create_engine(sqlite_url, echo=True)
SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
SendFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)
RecvFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)

50
tests/test_model_copy.py Normal file
View File

@ -0,0 +1,50 @@
from typing import Optional
from sqlmodel import Field, Session, SQLModel, create_engine
def test_model_copy(clear_sqlmodel):
"""Test validation of implicit and explict None values.
# For consistency with pydantic, validators are not to be called on
# arguments that are not explicitly provided.
https://github.com/tiangolo/sqlmodel/issues/230
https://github.com/samuelcolvin/pydantic/issues/1223
"""
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
session.add(hero)
session.commit()
session.refresh(hero)
model_copy = hero.model_copy(update={"name": "Deadpond Copy"})
assert (
model_copy.name == "Deadpond Copy"
and model_copy.secret_name == "Dive Wilson"
and model_copy.age == 25
)
db_hero = session.get(Hero, hero.id)
db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"})
assert (
db_copy.name == "Deadpond Copy"
and db_copy.secret_name == "Dive Wilson"
and db_copy.age == 25
)

View File

@ -1,8 +1,14 @@
from typing import Optional
import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Session, SQLModel, create_engine
from typing_extensions import Annotated
MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]
def test_nullable_fields(clear_sqlmodel, caplog):
@ -13,6 +19,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
)
required_value: str
optional_default_ellipsis: Optional[str] = Field(default=...)
optional_no_field: Optional[str]
optional_no_field_default: Optional[str] = Field(description="no default")
optional_default_none: Optional[str] = Field(default=None)
optional_non_nullable: Optional[str] = Field(
nullable=False,
@ -49,6 +57,13 @@ def test_nullable_fields(clear_sqlmodel, caplog):
str_default_str_nullable: str = Field(default="default", nullable=True)
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
base_url: AnyUrl
optional_url: Optional[MoveSharedUrl] = Field(default=None, description="")
url: MoveSharedUrl
annotated_url: Annotated[MoveSharedUrl, Field(description="")]
annotated_optional_url: Annotated[
Optional[MoveSharedUrl], Field(description="")
] = None
engine = create_engine("sqlite://", echo=True)
SQLModel.metadata.create_all(engine)
@ -59,6 +74,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
assert "primary_key INTEGER NOT NULL," in create_table_log
assert "required_value VARCHAR NOT NULL," in create_table_log
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
assert "optional_no_field VARCHAR," in create_table_log
assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log
assert "optional_default_none VARCHAR," in create_table_log
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
assert "optional_nullable VARCHAR," in create_table_log
@ -77,6 +94,11 @@ def test_nullable_fields(clear_sqlmodel, caplog):
assert "str_default_str_nullable VARCHAR," in create_table_log
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
assert "base_url VARCHAR NOT NULL," in create_table_log
assert "optional_url VARCHAR(512), " in create_table_log
assert "url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_optional_url VARCHAR(512)," in create_table_log
# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420