mirror of
https://github.com/PaiGramTeam/sqlmodel.git
synced 2024-11-21 14:48:30 +00:00
Merge pull request #2 from honglei/main
get_column_from_field support functional sa_column
This commit is contained in:
commit
bcb6f32128
@ -55,10 +55,7 @@ if sys.version_info >= (3, 8):
|
||||
else:
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
else:
|
||||
from typing_extensions import Annotated
|
||||
from typing_extensions import Annotated, _AnnotatedAlias
|
||||
|
||||
_T = TypeVar("_T")
|
||||
NoArgAnyCallable = Callable[[], Any]
|
||||
@ -167,7 +164,7 @@ def Field(
|
||||
unique: bool = False,
|
||||
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
|
||||
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
|
||||
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
|
||||
sa_column_kwargs: Union[
|
||||
Mapping[str, Any], PydanticUndefinedType
|
||||
@ -440,17 +437,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
|
||||
def _is_optional_or_union(type_: Optional[type]) -> bool:
|
||||
if sys.version_info >= (3, 10):
|
||||
return get_origin(type_) in (types.UnionType, Union)
|
||||
else:
|
||||
return get_origin(type_) is Union
|
||||
|
||||
|
||||
def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||
type_: Optional[type] = field.annotation
|
||||
type_: Optional[type] | _AnnotatedAlias = field.annotation
|
||||
|
||||
# Resolve Optional/Union fields
|
||||
def is_optional_or_union(type_: Optional[type]) -> bool:
|
||||
if sys.version_info >= (3, 10):
|
||||
return get_origin(type_) in (types.UnionType, Union)
|
||||
else:
|
||||
return get_origin(type_) is Union
|
||||
|
||||
if type_ is not None and is_optional_or_union(type_):
|
||||
if type_ is not None and _is_optional_or_union(type_):
|
||||
bases = get_args(type_)
|
||||
if len(bases) > 2:
|
||||
raise RuntimeError(
|
||||
@ -462,14 +461,20 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||
# UrlConstraints(max_length=512,
|
||||
# allowed_schemes=['smb', 'ftp', 'file']) ]
|
||||
if type_ is pydantic.AnyUrl:
|
||||
meta = field.metadata[0]
|
||||
return AutoString(length=meta.max_length)
|
||||
if field.metadata:
|
||||
meta = field.metadata[0]
|
||||
return AutoString(length=meta.max_length)
|
||||
else:
|
||||
return AutoString
|
||||
|
||||
if get_origin(type_) is Annotated:
|
||||
org_type = get_origin(type_)
|
||||
if org_type is Annotated:
|
||||
type2 = get_args(type_)[0]
|
||||
if type2 is pydantic.AnyUrl:
|
||||
meta = get_args(type_)[1]
|
||||
return AutoString(length=meta.max_length)
|
||||
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
|
||||
return AutoString(type_.__metadata__[0].max_length)
|
||||
|
||||
# The 3rd is PydanticGeneralMetadata
|
||||
metadata = _get_field_metadata(field)
|
||||
@ -519,11 +524,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||
|
||||
|
||||
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
||||
"""
|
||||
sa_column > field attributes > annotation info
|
||||
"""
|
||||
sa_column = getattr(field, "sa_column", PydanticUndefined)
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
if isinstance(sa_column, MappedColumn):
|
||||
return sa_column.column
|
||||
if isinstance(sa_column, types.FunctionType):
|
||||
col = sa_column()
|
||||
assert isinstance(col, Column)
|
||||
return col
|
||||
sa_type = get_sqlalchemy_type(field)
|
||||
primary_key = getattr(field, "primary_key", False)
|
||||
index = getattr(field, "index", PydanticUndefined)
|
||||
@ -587,6 +599,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
# in the Pydantic model so that when SQLAlchemy sets attributes that are
|
||||
# added (e.g. when querying from DB) to the __fields_set__, this already exists
|
||||
object.__setattr__(new_object, "__pydantic_fields_set__", set())
|
||||
if not hasattr(new_object, "__pydantic_extra__"):
|
||||
object.__setattr__(new_object, "__pydantic_extra__", None)
|
||||
if not hasattr(new_object, "__pydantic_private__"):
|
||||
object.__setattr__(new_object, "__pydantic_private__", None)
|
||||
return new_object
|
||||
|
||||
def __init__(__pydantic_self__, **data: Any) -> None:
|
||||
@ -636,7 +652,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
# remove defaults so they don't get validated
|
||||
data = {}
|
||||
for key, value in validated:
|
||||
field = cls.model_fields[key]
|
||||
field = cls.model_fields.get(key)
|
||||
|
||||
if field is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
hasattr(field, "default")
|
||||
@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
|
||||
return False
|
||||
if field.annotation is None or field.annotation is NoneType:
|
||||
return True
|
||||
if get_origin(field.annotation) is Union:
|
||||
if _is_optional_or_union(field.annotation):
|
||||
for base in get_args(field.annotation):
|
||||
if base is NoneType:
|
||||
return True
|
||||
|
||||
return False
|
||||
return False
|
||||
|
||||
|
@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
|
||||
class AutoString(types.TypeDecorator): # type: ignore
|
||||
|
||||
impl = types.String
|
||||
cache_ok = True
|
||||
mysql_default_length = 255
|
||||
|
78
tests/test_class_hierarchy.py
Normal file
78
tests/test_class_hierarchy.py
Normal file
@ -0,0 +1,78 @@
|
||||
import datetime
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from pydantic import AnyUrl, UrlConstraints
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
Integer,
|
||||
SQLModel,
|
||||
String,
|
||||
create_engine,
|
||||
)
|
||||
from typing_extensions import Annotated
|
||||
|
||||
MoveSharedUrl = Annotated[
|
||||
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
|
||||
def test_field_resuse():
|
||||
class BasicFileLog(SQLModel):
|
||||
resourceID: int = Field(
|
||||
sa_column=lambda: Column(Integer, index=True), description=""" """
|
||||
)
|
||||
transportID: Annotated[int | None, Field(description=" for ")] = None
|
||||
fileName: str = Field(
|
||||
sa_column=lambda: Column(String, index=True), description=""" """
|
||||
)
|
||||
fileSize: int | None = Field(
|
||||
sa_column=lambda: Column(BigInteger), ge=0, description=""" """
|
||||
)
|
||||
beginTime: datetime.datetime | None = Field(
|
||||
sa_column=lambda: Column(
|
||||
DateTime(timezone=True),
|
||||
index=True,
|
||||
),
|
||||
description="",
|
||||
)
|
||||
|
||||
class SendFileLog(BasicFileLog, table=True):
|
||||
id: int | None = Field(
|
||||
sa_column=Column(Integer, primary_key=True, autoincrement=True),
|
||||
description=""" """,
|
||||
)
|
||||
sendUser: str
|
||||
dstUrl: MoveSharedUrl | None
|
||||
|
||||
class RecvFileLog(BasicFileLog, table=True):
|
||||
id: int | None = Field(
|
||||
sa_column=Column(Integer, primary_key=True, autoincrement=True),
|
||||
description=""" """,
|
||||
)
|
||||
recvUser: str
|
||||
|
||||
sqlite_file_name = "database.db"
|
||||
sqlite_url = f"sqlite:///{sqlite_file_name}"
|
||||
|
||||
engine = create_engine(sqlite_url, echo=True)
|
||||
SQLModel.metadata.drop_all(engine)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
SendFileLog(
|
||||
sendUser="j",
|
||||
resourceID=1,
|
||||
fileName="a.txt",
|
||||
fileSize=3234,
|
||||
beginTime=datetime.datetime.now(),
|
||||
)
|
||||
RecvFileLog(
|
||||
sendUser="j",
|
||||
resourceID=1,
|
||||
fileName="a.txt",
|
||||
fileSize=3234,
|
||||
beginTime=datetime.datetime.now(),
|
||||
)
|
50
tests/test_model_copy.py
Normal file
50
tests/test_model_copy.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine
|
||||
|
||||
|
||||
def test_model_copy(clear_sqlmodel):
|
||||
"""Test validation of implicit and explict None values.
|
||||
|
||||
# For consistency with pydantic, validators are not to be called on
|
||||
# arguments that are not explicitly provided.
|
||||
|
||||
https://github.com/tiangolo/sqlmodel/issues/230
|
||||
https://github.com/samuelcolvin/pydantic/issues/1223
|
||||
|
||||
"""
|
||||
|
||||
class Hero(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
name: str
|
||||
secret_name: str
|
||||
age: Optional[int] = None
|
||||
|
||||
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(hero)
|
||||
session.commit()
|
||||
session.refresh(hero)
|
||||
|
||||
model_copy = hero.model_copy(update={"name": "Deadpond Copy"})
|
||||
|
||||
assert (
|
||||
model_copy.name == "Deadpond Copy"
|
||||
and model_copy.secret_name == "Dive Wilson"
|
||||
and model_copy.age == 25
|
||||
)
|
||||
|
||||
db_hero = session.get(Hero, hero.id)
|
||||
|
||||
db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"})
|
||||
|
||||
assert (
|
||||
db_copy.name == "Deadpond Copy"
|
||||
and db_copy.secret_name == "Dive Wilson"
|
||||
and db_copy.age == 25
|
||||
)
|
@ -1,8 +1,14 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import AnyUrl, UrlConstraints
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine
|
||||
from typing_extensions import Annotated
|
||||
|
||||
MoveSharedUrl = Annotated[
|
||||
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
|
||||
]
|
||||
|
||||
|
||||
def test_nullable_fields(clear_sqlmodel, caplog):
|
||||
@ -13,6 +19,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
||||
)
|
||||
required_value: str
|
||||
optional_default_ellipsis: Optional[str] = Field(default=...)
|
||||
optional_no_field: Optional[str]
|
||||
optional_no_field_default: Optional[str] = Field(description="no default")
|
||||
optional_default_none: Optional[str] = Field(default=None)
|
||||
optional_non_nullable: Optional[str] = Field(
|
||||
nullable=False,
|
||||
@ -49,6 +57,13 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
||||
str_default_str_nullable: str = Field(default="default", nullable=True)
|
||||
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
|
||||
str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
|
||||
base_url: AnyUrl
|
||||
optional_url: Optional[MoveSharedUrl] = Field(default=None, description="")
|
||||
url: MoveSharedUrl
|
||||
annotated_url: Annotated[MoveSharedUrl, Field(description="")]
|
||||
annotated_optional_url: Annotated[
|
||||
Optional[MoveSharedUrl], Field(description="")
|
||||
] = None
|
||||
|
||||
engine = create_engine("sqlite://", echo=True)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
@ -59,6 +74,8 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
||||
assert "primary_key INTEGER NOT NULL," in create_table_log
|
||||
assert "required_value VARCHAR NOT NULL," in create_table_log
|
||||
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
|
||||
assert "optional_no_field VARCHAR," in create_table_log
|
||||
assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log
|
||||
assert "optional_default_none VARCHAR," in create_table_log
|
||||
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
|
||||
assert "optional_nullable VARCHAR," in create_table_log
|
||||
@ -77,6 +94,11 @@ def test_nullable_fields(clear_sqlmodel, caplog):
|
||||
assert "str_default_str_nullable VARCHAR," in create_table_log
|
||||
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
|
||||
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
|
||||
assert "base_url VARCHAR NOT NULL," in create_table_log
|
||||
assert "optional_url VARCHAR(512), " in create_table_log
|
||||
assert "url VARCHAR(512) NOT NULL," in create_table_log
|
||||
assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log
|
||||
assert "annotated_optional_url VARCHAR(512)," in create_table_log
|
||||
|
||||
|
||||
# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420
|
||||
|
Loading…
Reference in New Issue
Block a user