diff --git a/sqlmodel/main.py b/sqlmodel/main.py index ab41b7b..dbc05c4 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -55,10 +55,7 @@ if sys.version_info >= (3, 8): else: from typing_extensions import get_args, get_origin -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated +from typing_extensions import Annotated, _AnnotatedAlias _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] @@ -167,7 +164,7 @@ def Field( unique: bool = False, nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined, - sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore + sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, sa_column_kwargs: Union[ Mapping[str, Any], PydanticUndefinedType @@ -440,17 +437,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) +def _is_optional_or_union(type_: Optional[type]) -> bool: + if sys.version_info >= (3, 10): + return get_origin(type_) in (types.UnionType, Union) + else: + return get_origin(type_) is Union + + def get_sqlalchemy_type(field: FieldInfo) -> Any: - type_: Optional[type] = field.annotation + type_: Optional[type] | _AnnotatedAlias = field.annotation # Resolve Optional/Union fields - def is_optional_or_union(type_: Optional[type]) -> bool: - if sys.version_info >= (3, 10): - return get_origin(type_) in (types.UnionType, Union) - else: - return get_origin(type_) is Union - if type_ is not None and is_optional_or_union(type_): + if type_ is not None and _is_optional_or_union(type_): bases = get_args(type_) if len(bases) > 2: raise RuntimeError( @@ -462,14 +461,20 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: # UrlConstraints(max_length=512, # allowed_schemes=['smb', 'ftp', 'file']) ] if type_ is pydantic.AnyUrl: - meta = field.metadata[0] - return AutoString(length=meta.max_length) + if field.metadata: + meta = field.metadata[0] + return AutoString(length=meta.max_length) + else: + return AutoString - if get_origin(type_) is Annotated: + org_type = get_origin(type_) + if org_type is Annotated: type2 = get_args(type_)[0] if type2 is pydantic.AnyUrl: meta = get_args(type_)[1] return AutoString(length=meta.max_length) + elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias: + return AutoString(type_.__metadata__[0].max_length) # The 3rd is PydanticGeneralMetadata metadata = _get_field_metadata(field) @@ -519,11 +524,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: def get_column_from_field(field: FieldInfo) -> Column: # type: ignore + """ + sa_column > field attributes > annotation info + """ sa_column = getattr(field, "sa_column", PydanticUndefined) if isinstance(sa_column, Column): return sa_column if isinstance(sa_column, MappedColumn): return sa_column.column + if isinstance(sa_column, types.FunctionType): + col = sa_column() + assert isinstance(col, Column) + return col sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) index = getattr(field, "index", PydanticUndefined) @@ -587,6 +599,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # in the Pydantic model so that when SQLAlchemy sets attributes that are # added (e.g. when querying from DB) to the __fields_set__, this already exists object.__setattr__(new_object, "__pydantic_fields_set__", set()) + if not hasattr(new_object, "__pydantic_extra__"): + object.__setattr__(new_object, "__pydantic_extra__", None) + if not hasattr(new_object, "__pydantic_private__"): + object.__setattr__(new_object, "__pydantic_private__", None) return new_object def __init__(__pydantic_self__, **data: Any) -> None: @@ -636,7 +652,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # remove defaults so they don't get validated data = {} for key, value in validated: - field = cls.model_fields[key] + field = cls.model_fields.get(key) + + if field is None: + continue if ( hasattr(field, "default") @@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool: return False if field.annotation is None or field.annotation is NoneType: return True - if get_origin(field.annotation) is Union: + if _is_optional_or_union(field.annotation): for base in get_args(field.annotation): if base is NoneType: return True + return False return False diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index da6551b..aa30950 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine class AutoString(types.TypeDecorator): # type: ignore - impl = types.String cache_ok = True mysql_default_length = 255 diff --git a/tests/test_class_hierarchy.py b/tests/test_class_hierarchy.py new file mode 100644 index 0000000..81314e4 --- /dev/null +++ b/tests/test_class_hierarchy.py @@ -0,0 +1,78 @@ +import datetime +import sys + +import pytest +from pydantic import AnyUrl, UrlConstraints +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + Integer, + SQLModel, + String, + create_engine, +) +from typing_extensions import Annotated + +MoveSharedUrl = Annotated[ + AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"]) +] + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_field_resuse(): + class BasicFileLog(SQLModel): + resourceID: int = Field( + sa_column=lambda: Column(Integer, index=True), description=""" """ + ) + transportID: Annotated[int | None, Field(description=" for ")] = None + fileName: str = Field( + sa_column=lambda: Column(String, index=True), description=""" """ + ) + fileSize: int | None = Field( + sa_column=lambda: Column(BigInteger), ge=0, description=""" """ + ) + beginTime: datetime.datetime | None = Field( + sa_column=lambda: Column( + DateTime(timezone=True), + index=True, + ), + description="", + ) + + class SendFileLog(BasicFileLog, table=True): + id: int | None = Field( + sa_column=Column(Integer, primary_key=True, autoincrement=True), + description=""" """, + ) + sendUser: str + dstUrl: MoveSharedUrl | None + + class RecvFileLog(BasicFileLog, table=True): + id: int | None = Field( + sa_column=Column(Integer, primary_key=True, autoincrement=True), + description=""" """, + ) + recvUser: str + + sqlite_file_name = "database.db" + sqlite_url = f"sqlite:///{sqlite_file_name}" + + engine = create_engine(sqlite_url, echo=True) + SQLModel.metadata.drop_all(engine) + SQLModel.metadata.create_all(engine) + SendFileLog( + sendUser="j", + resourceID=1, + fileName="a.txt", + fileSize=3234, + beginTime=datetime.datetime.now(), + ) + RecvFileLog( + sendUser="j", + resourceID=1, + fileName="a.txt", + fileSize=3234, + beginTime=datetime.datetime.now(), + ) diff --git a/tests/test_model_copy.py b/tests/test_model_copy.py new file mode 100644 index 0000000..d21a6e4 --- /dev/null +++ b/tests/test_model_copy.py @@ -0,0 +1,50 @@ +from typing import Optional + +from sqlmodel import Field, Session, SQLModel, create_engine + + +def test_model_copy(clear_sqlmodel): + """Test validation of implicit and explict None values. + + # For consistency with pydantic, validators are not to be called on + # arguments that are not explicitly provided. + + https://github.com/tiangolo/sqlmodel/issues/230 + https://github.com/samuelcolvin/pydantic/issues/1223 + + """ + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25) + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero) + session.commit() + session.refresh(hero) + + model_copy = hero.model_copy(update={"name": "Deadpond Copy"}) + + assert ( + model_copy.name == "Deadpond Copy" + and model_copy.secret_name == "Dive Wilson" + and model_copy.age == 25 + ) + + db_hero = session.get(Hero, hero.id) + + db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"}) + + assert ( + db_copy.name == "Deadpond Copy" + and db_copy.secret_name == "Dive Wilson" + and db_copy.age == 25 + ) diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 1c8b37b..041a889 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -1,8 +1,14 @@ from typing import Optional import pytest +from pydantic import AnyUrl, UrlConstraints from sqlalchemy.exc import IntegrityError from sqlmodel import Field, Session, SQLModel, create_engine +from typing_extensions import Annotated + +MoveSharedUrl = Annotated[ + AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"]) +] def test_nullable_fields(clear_sqlmodel, caplog): @@ -13,6 +19,8 @@ def test_nullable_fields(clear_sqlmodel, caplog): ) required_value: str optional_default_ellipsis: Optional[str] = Field(default=...) + optional_no_field: Optional[str] + optional_no_field_default: Optional[str] = Field(description="no default") optional_default_none: Optional[str] = Field(default=None) optional_non_nullable: Optional[str] = Field( nullable=False, @@ -49,6 +57,13 @@ def test_nullable_fields(clear_sqlmodel, caplog): str_default_str_nullable: str = Field(default="default", nullable=True) str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False) str_default_ellipsis_nullable: str = Field(default=..., nullable=True) + base_url: AnyUrl + optional_url: Optional[MoveSharedUrl] = Field(default=None, description="") + url: MoveSharedUrl + annotated_url: Annotated[MoveSharedUrl, Field(description="")] + annotated_optional_url: Annotated[ + Optional[MoveSharedUrl], Field(description="") + ] = None engine = create_engine("sqlite://", echo=True) SQLModel.metadata.create_all(engine) @@ -59,6 +74,8 @@ def test_nullable_fields(clear_sqlmodel, caplog): assert "primary_key INTEGER NOT NULL," in create_table_log assert "required_value VARCHAR NOT NULL," in create_table_log assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log + assert "optional_no_field VARCHAR," in create_table_log + assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log assert "optional_default_none VARCHAR," in create_table_log assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log assert "optional_nullable VARCHAR," in create_table_log @@ -77,6 +94,11 @@ def test_nullable_fields(clear_sqlmodel, caplog): assert "str_default_str_nullable VARCHAR," in create_table_log assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log assert "str_default_ellipsis_nullable VARCHAR," in create_table_log + assert "base_url VARCHAR NOT NULL," in create_table_log + assert "optional_url VARCHAR(512), " in create_table_log + assert "url VARCHAR(512) NOT NULL," in create_table_log + assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log + assert "annotated_optional_url VARCHAR(512)," in create_table_log # Test for regression in https://github.com/tiangolo/sqlmodel/issues/420