mirror of
https://github.com/PaiGramTeam/sqlmodel.git
synced 2024-11-21 22:58:22 +00:00
Merge pull request #2 from honglei/main
get_column_from_field support functional sa_column
This commit is contained in:
commit
bcb6f32128
@ -55,10 +55,7 @@ if sys.version_info >= (3, 8):
|
|||||||
else:
|
else:
|
||||||
from typing_extensions import get_args, get_origin
|
from typing_extensions import get_args, get_origin
|
||||||
|
|
||||||
if sys.version_info >= (3, 9):
|
from typing_extensions import Annotated, _AnnotatedAlias
|
||||||
from typing import Annotated
|
|
||||||
else:
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
NoArgAnyCallable = Callable[[], Any]
|
NoArgAnyCallable = Callable[[], Any]
|
||||||
@ -167,7 +164,7 @@ def Field(
|
|||||||
unique: bool = False,
|
unique: bool = False,
|
||||||
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||||
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||||
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
|
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
|
||||||
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
|
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
|
||||||
sa_column_kwargs: Union[
|
sa_column_kwargs: Union[
|
||||||
Mapping[str, Any], PydanticUndefinedType
|
Mapping[str, Any], PydanticUndefinedType
|
||||||
@ -440,17 +437,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_optional_or_union(type_: Optional[type]) -> bool:
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
return get_origin(type_) in (types.UnionType, Union)
|
||||||
|
else:
|
||||||
|
return get_origin(type_) is Union
|
||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||||
type_: Optional[type] = field.annotation
|
type_: Optional[type] | _AnnotatedAlias = field.annotation
|
||||||
|
|
||||||
# Resolve Optional/Union fields
|
# Resolve Optional/Union fields
|
||||||
def is_optional_or_union(type_: Optional[type]) -> bool:
|
|
||||||
if sys.version_info >= (3, 10):
|
|
||||||
return get_origin(type_) in (types.UnionType, Union)
|
|
||||||
else:
|
|
||||||
return get_origin(type_) is Union
|
|
||||||
|
|
||||||
if type_ is not None and is_optional_or_union(type_):
|
if type_ is not None and _is_optional_or_union(type_):
|
||||||
bases = get_args(type_)
|
bases = get_args(type_)
|
||||||
if len(bases) > 2:
|
if len(bases) > 2:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -462,14 +461,20 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
|||||||
# UrlConstraints(max_length=512,
|
# UrlConstraints(max_length=512,
|
||||||
# allowed_schemes=['smb', 'ftp', 'file']) ]
|
# allowed_schemes=['smb', 'ftp', 'file']) ]
|
||||||
if type_ is pydantic.AnyUrl:
|
if type_ is pydantic.AnyUrl:
|
||||||
meta = field.metadata[0]
|
if field.metadata:
|
||||||
return AutoString(length=meta.max_length)
|
meta = field.metadata[0]
|
||||||
|
return AutoString(length=meta.max_length)
|
||||||
|
else:
|
||||||
|
return AutoString
|
||||||
|
|
||||||
if get_origin(type_) is Annotated:
|
org_type = get_origin(type_)
|
||||||
|
if org_type is Annotated:
|
||||||
type2 = get_args(type_)[0]
|
type2 = get_args(type_)[0]
|
||||||
if type2 is pydantic.AnyUrl:
|
if type2 is pydantic.AnyUrl:
|
||||||
meta = get_args(type_)[1]
|
meta = get_args(type_)[1]
|
||||||
return AutoString(length=meta.max_length)
|
return AutoString(length=meta.max_length)
|
||||||
|
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
|
||||||
|
return AutoString(type_.__metadata__[0].max_length)
|
||||||
|
|
||||||
# The 3rd is PydanticGeneralMetadata
|
# The 3rd is PydanticGeneralMetadata
|
||||||
metadata = _get_field_metadata(field)
|
metadata = _get_field_metadata(field)
|
||||||
@ -519,11 +524,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
||||||
|
"""
|
||||||
|
sa_column > field attributes > annotation info
|
||||||
|
"""
|
||||||
sa_column = getattr(field, "sa_column", PydanticUndefined)
|
sa_column = getattr(field, "sa_column", PydanticUndefined)
|
||||||
if isinstance(sa_column, Column):
|
if isinstance(sa_column, Column):
|
||||||
return sa_column
|
return sa_column
|
||||||
if isinstance(sa_column, MappedColumn):
|
if isinstance(sa_column, MappedColumn):
|
||||||
return sa_column.column
|
return sa_column.column
|
||||||
|
if isinstance(sa_column, types.FunctionType):
|
||||||
|
col = sa_column()
|
||||||
|
assert isinstance(col, Column)
|
||||||
|
return col
|
||||||
sa_type = get_sqlalchemy_type(field)
|
sa_type = get_sqlalchemy_type(field)
|
||||||
primary_key = getattr(field, "primary_key", False)
|
primary_key = getattr(field, "primary_key", False)
|
||||||
index = getattr(field, "index", PydanticUndefined)
|
index = getattr(field, "index", PydanticUndefined)
|
||||||
@ -587,6 +599,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
|||||||
# in the Pydantic model so that when SQLAlchemy sets attributes that are
|
# in the Pydantic model so that when SQLAlchemy sets attributes that are
|
||||||
# added (e.g. when querying from DB) to the __fields_set__, this already exists
|
# added (e.g. when querying from DB) to the __fields_set__, this already exists
|
||||||
object.__setattr__(new_object, "__pydantic_fields_set__", set())
|
object.__setattr__(new_object, "__pydantic_fields_set__", set())
|
||||||
|
if not hasattr(new_object, "__pydantic_extra__"):
|
||||||
|
object.__setattr__(new_object, "__pydantic_extra__", None)
|
||||||
|
if not hasattr(new_object, "__pydantic_private__"):
|
||||||
|
object.__setattr__(new_object, "__pydantic_private__", None)
|
||||||
return new_object
|
return new_object
|
||||||
|
|
||||||
def __init__(__pydantic_self__, **data: Any) -> None:
|
def __init__(__pydantic_self__, **data: Any) -> None:
|
||||||
@ -636,7 +652,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
|||||||
# remove defaults so they don't get validated
|
# remove defaults so they don't get validated
|
||||||
data = {}
|
data = {}
|
||||||
for key, value in validated:
|
for key, value in validated:
|
||||||
field = cls.model_fields[key]
|
field = cls.model_fields.get(key)
|
||||||
|
|
||||||
|
if field is None:
|
||||||
|
continue
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(field, "default")
|
hasattr(field, "default")
|
||||||
@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
|
|||||||
return False
|
return False
|
||||||
if field.annotation is None or field.annotation is NoneType:
|
if field.annotation is None or field.annotation is NoneType:
|
||||||
return True
|
return True
|
||||||
if get_origin(field.annotation) is Union:
|
if _is_optional_or_union(field.annotation):
|
||||||
for base in get_args(field.annotation):
|
for base in get_args(field.annotation):
|
||||||
if base is NoneType:
|
if base is NoneType:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine
|
|||||||
|
|
||||||
|
|
||||||
class AutoString(types.TypeDecorator): # type: ignore
|
class AutoString(types.TypeDecorator): # type: ignore
|
||||||
|
|
||||||
impl = types.String
|
impl = types.String
|
||||||
cache_ok = True
|
cache_ok = True
|
||||||
mysql_default_length = 255
|
mysql_default_length = 255
|
||||||
|
78
tests/test_class_hierarchy.py
Normal file
78
tests/test_class_hierarchy.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import datetime
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import AnyUrl, UrlConstraints
|
||||||
|
from sqlmodel import (
|
||||||
|
BigInteger,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Field,
|
||||||
|
Integer,
|
||||||
|
SQLModel,
|
||||||
|
String,
|
||||||
|
create_engine,
|
||||||
|
)
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
MoveSharedUrl = Annotated[
|
||||||
|
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
|
||||||
|
def test_field_resuse():
|
||||||
|
class BasicFileLog(SQLModel):
|
||||||
|
resourceID: int = Field(
|
||||||
|
sa_column=lambda: Column(Integer, index=True), description=""" """
|
||||||
|
)
|
||||||
|
transportID: Annotated[int | None, Field(description=" for ")] = None
|
||||||
|
fileName: str = Field(
|
||||||
|
sa_column=lambda: Column(String, index=True), description=""" """
|
||||||
|
)
|
||||||
|
fileSize: int | None = Field(
|
||||||
|
sa_column=lambda: Column(BigInteger), ge=0, description=""" """
|
||||||
|
)
|
||||||
|
beginTime: datetime.datetime | None = Field(
|
||||||
|
sa_column=lambda: Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
index=True,
|
||||||
|
),
|
||||||
|
description="",
|
||||||
|
)
|
||||||
|
|
||||||
|
class SendFileLog(BasicFileLog, table=True):
|
||||||
|
id: int | None = Field(
|
||||||
|
sa_column=Column(Integer, primary_key=True, autoincrement=True),
|
||||||
|
description=""" """,
|
||||||
|
)
|
||||||
|
sendUser: str
|
||||||
|
dstUrl: MoveSharedUrl | None
|
||||||
|
|
||||||
|
class RecvFileLog(BasicFileLog, table=True):
|
||||||
|
id: int | None = Field(
|
||||||
|
sa_column=Column(Integer, primary_key=True, autoincrement=True),
|
||||||
|
description=""" """,
|
||||||
|
)
|
||||||
|
recvUser: str
|
||||||
|
|
||||||
|
sqlite_file_name = "database.db"
|
||||||
|
sqlite_url = f"sqlite:///{sqlite_file_name}"
|
||||||
|
|
||||||
|
engine = create_engine(sqlite_url, echo=True)
|
||||||
|
SQLModel.metadata.drop_all(engine)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
SendFileLog(
|
||||||
|
sendUser="j",
|
||||||
|
resourceID=1,
|
||||||
|
fileName="a.txt",
|
||||||
|
fileSize=3234,
|
||||||
|
beginTime=datetime.datetime.now(),
|
||||||
|
)
|
||||||
|
RecvFileLog(
|
||||||
|
sendUser="j",
|
||||||
|
resourceID=1,
|
||||||
|
fileName="a.txt",
|
||||||
|
fileSize=3234,
|
||||||
|
beginTime=datetime.datetime.now(),
|
||||||
|
)
|
50
tests/test_model_copy.py
Normal file
50
tests/test_model_copy.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlmodel import Field, Session, SQLModel, create_engine
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_copy(clear_sqlmodel):
|
||||||
|
"""Test validation of implicit and explict None values.
|
||||||
|
|
||||||
|
# For consistency with pydantic, validators are not to be called on
|
||||||
|
# arguments that are not explicitly provided.
|
||||||
|
|
||||||
|
https://github.com/tiangolo/sqlmodel/issues/230
|
||||||
|
https://github.com/samuelcolvin/pydantic/issues/1223
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str
|
||||||
|
secret_name: str
|
||||||
|
age: Optional[int] = None
|
||||||
|
|
||||||
|
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
|
||||||
|
|
||||||
|
engine = create_engine("sqlite://")
|
||||||
|
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
session.add(hero)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(hero)
|
||||||
|
|
||||||
|
model_copy = hero.model_copy(update={"name": "Deadpond Copy"})
|
||||||
|
|
||||||
|
assert (
|
||||||
|
model_copy.name == "Deadpond Copy"
|
||||||
|
and model_copy.secret_name == "Dive Wilson"
|
||||||
|
and model_copy.age == 25
|
||||||
|
)
|
||||||
|
|
||||||
|
db_hero = session.get(Hero, hero.id)
|
||||||
|
|
||||||
|
db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"})
|
||||||
|
|
||||||
|
assert (
|
||||||
|
db_copy.name == "Deadpond Copy"
|
||||||
|
and db_copy.secret_name == "Dive Wilson"
|
||||||
|
and db_copy.age == 25
|
||||||
|
)
|
@ -1,8 +1,14 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import AnyUrl, UrlConstraints
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlmodel import Field, Session, SQLModel, create_engine
|
from sqlmodel import Field, Session, SQLModel, create_engine
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
MoveSharedUrl = Annotated[
|
||||||
|
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_nullable_fields(clear_sqlmodel, caplog):
|
def test_nullable_fields(clear_sqlmodel, caplog):
|
||||||
@ -13,6 +19,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
|||||||
)
|
)
|
||||||
required_value: str
|
required_value: str
|
||||||
optional_default_ellipsis: Optional[str] = Field(default=...)
|
optional_default_ellipsis: Optional[str] = Field(default=...)
|
||||||
|
optional_no_field: Optional[str]
|
||||||
|
optional_no_field_default: Optional[str] = Field(description="no default")
|
||||||
optional_default_none: Optional[str] = Field(default=None)
|
optional_default_none: Optional[str] = Field(default=None)
|
||||||
optional_non_nullable: Optional[str] = Field(
|
optional_non_nullable: Optional[str] = Field(
|
||||||
nullable=False,
|
nullable=False,
|
||||||
@ -49,6 +57,13 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
|||||||
str_default_str_nullable: str = Field(default="default", nullable=True)
|
str_default_str_nullable: str = Field(default="default", nullable=True)
|
||||||
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
|
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
|
||||||
str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
|
str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
|
||||||
|
base_url: AnyUrl
|
||||||
|
optional_url: Optional[MoveSharedUrl] = Field(default=None, description="")
|
||||||
|
url: MoveSharedUrl
|
||||||
|
annotated_url: Annotated[MoveSharedUrl, Field(description="")]
|
||||||
|
annotated_optional_url: Annotated[
|
||||||
|
Optional[MoveSharedUrl], Field(description="")
|
||||||
|
] = None
|
||||||
|
|
||||||
engine = create_engine("sqlite://", echo=True)
|
engine = create_engine("sqlite://", echo=True)
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
@ -59,6 +74,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
|||||||
assert "primary_key INTEGER NOT NULL," in create_table_log
|
assert "primary_key INTEGER NOT NULL," in create_table_log
|
||||||
assert "required_value VARCHAR NOT NULL," in create_table_log
|
assert "required_value VARCHAR NOT NULL," in create_table_log
|
||||||
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
|
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
|
||||||
|
assert "optional_no_field VARCHAR," in create_table_log
|
||||||
|
assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log
|
||||||
assert "optional_default_none VARCHAR," in create_table_log
|
assert "optional_default_none VARCHAR," in create_table_log
|
||||||
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
|
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
|
||||||
assert "optional_nullable VARCHAR," in create_table_log
|
assert "optional_nullable VARCHAR," in create_table_log
|
||||||
@ -77,6 +94,11 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
|||||||
assert "str_default_str_nullable VARCHAR," in create_table_log
|
assert "str_default_str_nullable VARCHAR," in create_table_log
|
||||||
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
|
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
|
||||||
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
|
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
|
||||||
|
assert "base_url VARCHAR NOT NULL," in create_table_log
|
||||||
|
assert "optional_url VARCHAR(512), " in create_table_log
|
||||||
|
assert "url VARCHAR(512) NOT NULL," in create_table_log
|
||||||
|
assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log
|
||||||
|
assert "annotated_optional_url VARCHAR(512)," in create_table_log
|
||||||
|
|
||||||
|
|
||||||
# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420
|
# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420
|
||||||
|
Loading…
Reference in New Issue
Block a user