From f3e7811a801580813c8f30ea7b6a47bf653f656e Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 18:13:24 +0800 Subject: [PATCH 01/25] get_column_from_field support function --- sqlmodel/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index ab41b7b..fc99f00 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -524,6 +524,8 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore return sa_column if isinstance(sa_column, MappedColumn): return sa_column.column + if isinstance(sa_column, types.FunctionType): + return sa_column() sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) index = getattr(field, "index", PydanticUndefined) From 9e07c1c772b6526e98d981090f064273743db2e3 Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 22:38:03 +0800 Subject: [PATCH 02/25] fix type check for sa_column --- sqlmodel/main.py | 6 ++++-- sqlmodel/sql/sqltypes.py | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index fc99f00..ab1de9f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -167,7 +167,7 @@ def Field( unique: bool = False, nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined, - sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore + sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, sa_column_kwargs: Union[ Mapping[str, Any], PydanticUndefinedType @@ -525,7 +525,9 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore if isinstance(sa_column, MappedColumn): return sa_column.column if isinstance(sa_column, types.FunctionType): - return sa_column() + col = sa_column() + assert isinstance(col, Column) + return col sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) index = getattr(field, "index", PydanticUndefined) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index da6551b..aa30950 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine class AutoString(types.TypeDecorator): # type: ignore - impl = types.String cache_ok = True mysql_default_length = 255 From 5b49f778c3a358a5d9169a659ae36f8c6180a99e Mon Sep 17 00:00:00 2001 From: honglei Date: Wed, 16 Aug 2023 21:22:57 +0800 Subject: [PATCH 03/25] get_column_from_field:sa_column>field attribute>field annotation --- sqlmodel/main.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index ab1de9f..95a8278 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -440,17 +440,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) +def _is_optional_or_union(type_: Optional[type]) -> bool: + if sys.version_info >= (3, 10): + return get_origin(type_) in (types.UnionType, Union) + else: + return get_origin(type_) is Union + + def get_sqlalchemy_type(field: FieldInfo) -> Any: type_: Optional[type] = field.annotation # Resolve Optional/Union fields - def is_optional_or_union(type_: Optional[type]) -> bool: - if sys.version_info >= (3, 10): - return get_origin(type_) in (types.UnionType, Union) - else: - return get_origin(type_) is Union - if type_ is not None and is_optional_or_union(type_): + if type_ is not None and _is_optional_or_union(type_): bases = get_args(type_) if len(bases) > 2: raise RuntimeError( @@ -519,15 +521,27 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: def get_column_from_field(field: FieldInfo) -> Column: # type: ignore + """ + sa_column > field attributes > annotation info + """ sa_column = getattr(field, "sa_column", PydanticUndefined) + col: Column | None = None if isinstance(sa_column, Column): - return sa_column - if isinstance(sa_column, MappedColumn): - return sa_column.column - if isinstance(sa_column, types.FunctionType): + col = sa_column + elif isinstance(sa_column, MappedColumn): + col = sa_column.column + elif isinstance(sa_column, types.FunctionType): col = sa_column() - assert isinstance(col, Column) + if isinstance(col, Column): + # field attribute or field annotation -> Column.nullable + if col.nullable is PydanticUndefined: + col.nullable = _is_field_noneable(field) + # field.primary_key -> Column.primary_key + if col.primary_key is PydanticUndefined: + primary_key = getattr(field, "primary_key", False) + col.primary_key = primary_key return col + sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) index = getattr(field, "index", PydanticUndefined) @@ -661,14 +675,17 @@ def _is_field_noneable(field: FieldInfo) -> bool: return field.nullable if not field.is_required(): default = getattr(field, "original_default", field.default) - if default is PydanticUndefined: + if default is None: + return True + elif default is not PydanticUndefined: return False if field.annotation is None or field.annotation is NoneType: return True - if get_origin(field.annotation) is Union: + if _is_optional_or_union(field.annotation): for base in get_args(field.annotation): if base is NoneType: return True + return False return False From 6a5f373862f8cf05962c0bd05e1b987c5c3fe1f6 Mon Sep 17 00:00:00 2001 From: honglei Date: Wed, 16 Aug 2023 22:20:43 +0800 Subject: [PATCH 04/25] Revert "get_column_from_field:sa_column>field attribute>field annotation" --- sqlmodel/main.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 95a8278..fcd6433 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -525,23 +525,14 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore sa_column > field attributes > annotation info """ sa_column = getattr(field, "sa_column", PydanticUndefined) - col: Column | None = None if isinstance(sa_column, Column): - col = sa_column - elif isinstance(sa_column, MappedColumn): - col = sa_column.column - elif isinstance(sa_column, types.FunctionType): + return sa_column + if isinstance(sa_column, MappedColumn): + return sa_column.column + if isinstance(sa_column, types.FunctionType): col = sa_column() - if isinstance(col, Column): - # field attribute or field annotation -> Column.nullable - if col.nullable is PydanticUndefined: - col.nullable = _is_field_noneable(field) - # field.primary_key -> Column.primary_key - if col.primary_key is PydanticUndefined: - primary_key = getattr(field, "primary_key", False) - col.primary_key = primary_key + assert isinstance(col, Column) return col - sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) index = getattr(field, "index", PydanticUndefined) From 72dc89d92be6eb65dded955d97ccca65f59ae72f Mon Sep 17 00:00:00 2001 From: honglei Date: Wed, 16 Aug 2023 22:32:28 +0800 Subject: [PATCH 05/25] field is required by default, while nullable=True for Column --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index fcd6433..922422d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -668,7 +668,7 @@ def _is_field_noneable(field: FieldInfo) -> bool: default = getattr(field, "original_default", field.default) if default is None: return True - elif default is not PydanticUndefined: + elif default is PydanticUndefined: return False if field.annotation is None or field.annotation is NoneType: return True From fa8955c70043cda6dc7a57e33242d4cbdac031aa Mon Sep 17 00:00:00 2001 From: honglei Date: Wed, 16 Aug 2023 22:58:10 +0800 Subject: [PATCH 06/25] field required --- sqlmodel/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 922422d..a593a53 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -666,9 +666,7 @@ def _is_field_noneable(field: FieldInfo) -> bool: return field.nullable if not field.is_required(): default = getattr(field, "original_default", field.default) - if default is None: - return True - elif default is PydanticUndefined: + if default is PydanticUndefined: return False if field.annotation is None or field.annotation is NoneType: return True From 045f9bcc8aa74c2d888b7ee856eae7006f9311fe Mon Sep 17 00:00:00 2001 From: honglei Date: Wed, 16 Aug 2023 23:18:55 +0800 Subject: [PATCH 07/25] add test for pydantic.AnyURL --- sqlmodel/main.py | 5 ++--- tests/test_nullable.py | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a593a53..7e809f3 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -666,15 +666,14 @@ def _is_field_noneable(field: FieldInfo) -> bool: return field.nullable if not field.is_required(): default = getattr(field, "original_default", field.default) - if default is PydanticUndefined: - return False if field.annotation is None or field.annotation is NoneType: return True if _is_optional_or_union(field.annotation): for base in get_args(field.annotation): if base is NoneType: return True - + if default is PydanticUndefined: + return False return False return False diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 1c8b37b..2509da0 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -3,6 +3,9 @@ from typing import Optional import pytest from sqlalchemy.exc import IntegrityError from sqlmodel import Field, Session, SQLModel, create_engine +from typing_extensions import Annotated +from pydantic import AnyUrl, UrlConstraints +MoveSharedUrl = Annotated[AnyUrl, UrlConstraints(max_length=512, allowed_schemes=['smb', 'ftp','file'])] def test_nullable_fields(clear_sqlmodel, caplog): @@ -49,6 +52,7 @@ def test_nullable_fields(clear_sqlmodel, caplog): str_default_str_nullable: str = Field(default="default", nullable=True) str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False) str_default_ellipsis_nullable: str = Field(default=..., nullable=True) + annotated_any_url: MoveSharedUrl | None = Field(description="") engine = create_engine("sqlite://", echo=True) SQLModel.metadata.create_all(engine) @@ -77,6 +81,7 @@ def test_nullable_fields(clear_sqlmodel, caplog): assert "str_default_str_nullable VARCHAR," in create_table_log assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log assert "str_default_ellipsis_nullable VARCHAR," in create_table_log + assert "annotated_any_url VARCHAR(512)," in create_table_log # Test for regression in https://github.com/tiangolo/sqlmodel/issues/420 From 6e89ad374b499bc32fa9ba7abc51fe03881d9567 Mon Sep 17 00:00:00 2001 From: honglei Date: Wed, 16 Aug 2023 23:21:40 +0800 Subject: [PATCH 08/25] black/isort for test_nullable.py --- tests/test_nullable.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 2509da0..c21d311 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -1,11 +1,14 @@ from typing import Optional import pytest +from pydantic import AnyUrl, UrlConstraints from sqlalchemy.exc import IntegrityError from sqlmodel import Field, Session, SQLModel, create_engine from typing_extensions import Annotated -from pydantic import AnyUrl, UrlConstraints -MoveSharedUrl = Annotated[AnyUrl, UrlConstraints(max_length=512, allowed_schemes=['smb', 'ftp','file'])] + +MoveSharedUrl = Annotated[ + AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"]) +] def test_nullable_fields(clear_sqlmodel, caplog): From 6b7925d8ce9e721a6b892b81aaef7154e7ca945a Mon Sep 17 00:00:00 2001 From: honglei Date: Sun, 20 Aug 2023 21:10:32 +0800 Subject: [PATCH 09/25] fix _is_field_noneable --- sqlmodel/main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7e809f3..a593a53 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -666,14 +666,15 @@ def _is_field_noneable(field: FieldInfo) -> bool: return field.nullable if not field.is_required(): default = getattr(field, "original_default", field.default) + if default is PydanticUndefined: + return False if field.annotation is None or field.annotation is NoneType: return True if _is_optional_or_union(field.annotation): for base in get_args(field.annotation): if base is NoneType: return True - if default is PydanticUndefined: - return False + return False return False From 4e89361f9300bb50fa4facee7409d9b8fb03e746 Mon Sep 17 00:00:00 2001 From: honglei Date: Sun, 20 Aug 2023 21:50:30 +0800 Subject: [PATCH 10/25] add test for Class hierarchy --- tests/test_class_hierarchy.py | 79 +++++++++++++++++++++++++++++++++++ tests/test_nullable.py | 6 ++- 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 tests/test_class_hierarchy.py diff --git a/tests/test_class_hierarchy.py b/tests/test_class_hierarchy.py new file mode 100644 index 0000000..0f0ec1f --- /dev/null +++ b/tests/test_class_hierarchy.py @@ -0,0 +1,79 @@ +import datetime +import sys + +import pytest +from pydantic import AnyUrl, UrlConstraints +from typing_extensions import Annotated + +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + Integer, + SQLModel, + String, + create_engine, +) + +MoveSharedUrl = Annotated[ + AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"]) +] + + +@pytest.mark.skipif(sys.version_info < (3, 10)) +def test_field_resuse(): + class BasicFileLog(SQLModel): + resourceID: int = Field( + sa_column=lambda: Column(Integer, index=True), description=""" """ + ) + transportID: Annotated[int | None, Field(description=" for ")] = None + fileName: str = Field( + sa_column=lambda: Column(String, index=True), description=""" """ + ) + fileSize: int | None = Field( + sa_column=lambda: Column(BigInteger), ge=0, description=""" """ + ) + beginTime: datetime.datetime | None = Field( + sa_column=lambda: Column( + DateTime(timezone=True), + index=True, + ), + description="", + ) + + class SendFileLog(BasicFileLog, table=True): + id: int | None = Field( + sa_column=Column(Integer, primary_key=True, autoincrement=True), + description=""" """, + ) + sendUser: str + dstUrl: MoveSharedUrl | None + + class RecvFileLog(BasicFileLog, table=True): + id: int | None = Field( + sa_column=Column(Integer, primary_key=True, autoincrement=True), + description=""" """, + ) + recvUser: str + + sqlite_file_name = "database.db" + sqlite_url = f"sqlite:///{sqlite_file_name}" + + engine = create_engine(sqlite_url, echo=True) + SQLModel.metadata.drop_all(engine) + SQLModel.metadata.create_all(engine) + SendFileLog( + sendUser="j", + resourceID=1, + fileName="a.txt", + fileSize=3234, + beginTime=datetime.datetime.now(), + ) + RecvFileLog( + sendUser="j", + resourceID=1, + fileName="a.txt", + fileSize=3234, + beginTime=datetime.datetime.now(), + ) diff --git a/tests/test_nullable.py b/tests/test_nullable.py index c21d311..b4fa1bb 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -19,6 +19,8 @@ def test_nullable_fields(clear_sqlmodel, caplog): ) required_value: str optional_default_ellipsis: Optional[str] = Field(default=...) + optional_no_field: Optional[str] + optional_no_field_default: Optional[str] = Field(description="no default") optional_default_none: Optional[str] = Field(default=None) optional_non_nullable: Optional[str] = Field( nullable=False, @@ -55,7 +57,7 @@ def test_nullable_fields(clear_sqlmodel, caplog): str_default_str_nullable: str = Field(default="default", nullable=True) str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False) str_default_ellipsis_nullable: str = Field(default=..., nullable=True) - annotated_any_url: MoveSharedUrl | None = Field(description="") + annotated_any_url: Optional[MoveSharedUrl] = Field(description="") engine = create_engine("sqlite://", echo=True) SQLModel.metadata.create_all(engine) @@ -66,6 +68,8 @@ def test_nullable_fields(clear_sqlmodel, caplog): assert "primary_key INTEGER NOT NULL," in create_table_log assert "required_value VARCHAR NOT NULL," in create_table_log assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log + assert "optional_no_field VARCHAR," in create_table_log + assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log assert "optional_default_none VARCHAR," in create_table_log assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log assert "optional_nullable VARCHAR," in create_table_log From 7752780fda975939aab32eca8306e2700f8bb8dc Mon Sep 17 00:00:00 2001 From: honglei Date: Sun, 20 Aug 2023 21:54:01 +0800 Subject: [PATCH 11/25] fix isort --- tests/test_class_hierarchy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_class_hierarchy.py b/tests/test_class_hierarchy.py index 0f0ec1f..9e3c5f8 100644 --- a/tests/test_class_hierarchy.py +++ b/tests/test_class_hierarchy.py @@ -3,8 +3,6 @@ import sys import pytest from pydantic import AnyUrl, UrlConstraints -from typing_extensions import Annotated - from sqlmodel import ( BigInteger, Column, @@ -15,6 +13,7 @@ from sqlmodel import ( String, create_engine, ) +from typing_extensions import Annotated MoveSharedUrl = Annotated[ AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"]) From 499bc188576f3ae1442d9acb92568a5916805057 Mon Sep 17 00:00:00 2001 From: honglei Date: Sun, 20 Aug 2023 21:57:53 +0800 Subject: [PATCH 12/25] add reason for skipif --- tests/test_class_hierarchy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_class_hierarchy.py b/tests/test_class_hierarchy.py index 9e3c5f8..81314e4 100644 --- a/tests/test_class_hierarchy.py +++ b/tests/test_class_hierarchy.py @@ -20,7 +20,7 @@ MoveSharedUrl = Annotated[ ] -@pytest.mark.skipif(sys.version_info < (3, 10)) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_field_resuse(): class BasicFileLog(SQLModel): resourceID: int = Field( From 8e2b363c7731a1cf0543b45f3f7329759aa7edc5 Mon Sep 17 00:00:00 2001 From: honglei Date: Sun, 20 Aug 2023 22:01:45 +0800 Subject: [PATCH 13/25] annotation not null --- tests/test_nullable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_nullable.py b/tests/test_nullable.py index b4fa1bb..3117812 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -88,7 +88,7 @@ def test_nullable_fields(clear_sqlmodel, caplog): assert "str_default_str_nullable VARCHAR," in create_table_log assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log assert "str_default_ellipsis_nullable VARCHAR," in create_table_log - assert "annotated_any_url VARCHAR(512)," in create_table_log + assert "annotated_any_url VARCHAR(512) NOT NULL" in create_table_log # Test for regression in https://github.com/tiangolo/sqlmodel/issues/420 From f5fd8504b9f4a939a19d8d9f8428e8dee27b74ad Mon Sep 17 00:00:00 2001 From: "50bytes.dev" <50bytes.dev@gmail.com> Date: Thu, 24 Aug 2023 15:08:02 +0300 Subject: [PATCH 14/25] fix model_copy --- sqlmodel/main.py | 4 ++++ tests/test_model_copy.py | 49 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 tests/test_model_copy.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a593a53..3981058 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -596,6 +596,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # in the Pydantic model so that when SQLAlchemy sets attributes that are # added (e.g. when querying from DB) to the __fields_set__, this already exists object.__setattr__(new_object, "__pydantic_fields_set__", set()) + if not hasattr(new_object, "__pydantic_extra__"): + object.__setattr__(new_object, "__pydantic_extra__", None) + if not hasattr(new_object, "__pydantic_private__"): + object.__setattr__(new_object, "__pydantic_private__", None) return new_object def __init__(__pydantic_self__, **data: Any) -> None: diff --git a/tests/test_model_copy.py b/tests/test_model_copy.py new file mode 100644 index 0000000..30f7e0a --- /dev/null +++ b/tests/test_model_copy.py @@ -0,0 +1,49 @@ +from typing import Optional + +import pytest +from pydantic import field_validator +from pydantic.error_wrappers import ValidationError +from sqlmodel import SQLModel, create_engine, Session, Field + + +def test_model_copy(clear_sqlmodel): + """Test validation of implicit and explict None values. + + # For consistency with pydantic, validators are not to be called on + # arguments that are not explicitly provided. + + https://github.com/tiangolo/sqlmodel/issues/230 + https://github.com/samuelcolvin/pydantic/issues/1223 + + """ + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25) + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero) + session.commit() + session.refresh(hero) + + model_copy = hero.model_copy(update={"name": "Deadpond Copy"}) + + assert model_copy.name == "Deadpond Copy" and \ + model_copy.secret_name == "Dive Wilson" and \ + model_copy.age == 25 + + db_hero = session.get(Hero, hero.id) + + db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"}) + + assert db_copy.name == "Deadpond Copy" and \ + db_copy.secret_name == "Dive Wilson" and \ + db_copy.age == 25 From 45ab47266c1fd944e73013039efecbb6ca35d3a6 Mon Sep 17 00:00:00 2001 From: "50bytes.dev" <50bytes.dev@gmail.com> Date: Thu, 24 Aug 2023 18:41:34 +0300 Subject: [PATCH 15/25] fix --- sqlmodel/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3981058..dd41cc9 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -649,7 +649,10 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # remove defaults so they don't get validated data = {} for key, value in validated: - field = cls.model_fields[key] + field = cls.model_fields.get(key) + + if field is None: + continue if ( hasattr(field, "default") From f266da7b46d0919d7946156acfcdd03370384727 Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 18:11:07 +0800 Subject: [PATCH 16/25] black/isort test_model_copy.py --- tests/test_model_copy.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test_model_copy.py b/tests/test_model_copy.py index 30f7e0a..74cd534 100644 --- a/tests/test_model_copy.py +++ b/tests/test_model_copy.py @@ -3,7 +3,7 @@ from typing import Optional import pytest from pydantic import field_validator from pydantic.error_wrappers import ValidationError -from sqlmodel import SQLModel, create_engine, Session, Field +from sqlmodel import Field, Session, SQLModel, create_engine def test_model_copy(clear_sqlmodel): @@ -36,14 +36,18 @@ def test_model_copy(clear_sqlmodel): model_copy = hero.model_copy(update={"name": "Deadpond Copy"}) - assert model_copy.name == "Deadpond Copy" and \ - model_copy.secret_name == "Dive Wilson" and \ - model_copy.age == 25 + assert ( + model_copy.name == "Deadpond Copy" + and model_copy.secret_name == "Dive Wilson" + and model_copy.age == 25 + ) db_hero = session.get(Hero, hero.id) db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"}) - assert db_copy.name == "Deadpond Copy" and \ - db_copy.secret_name == "Dive Wilson" and \ - db_copy.age == 25 + assert ( + db_copy.name == "Deadpond Copy" + and db_copy.secret_name == "Dive Wilson" + and db_copy.age == 25 + ) From ef56d08f2ac9118d4b558cbd97e8996b9d4fda1e Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 18:15:17 +0800 Subject: [PATCH 17/25] remove unused import in test_model_copy.py --- tests/test_model_copy.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_model_copy.py b/tests/test_model_copy.py index 74cd534..d21a6e4 100644 --- a/tests/test_model_copy.py +++ b/tests/test_model_copy.py @@ -1,8 +1,5 @@ from typing import Optional -import pytest -from pydantic import field_validator -from pydantic.error_wrappers import ValidationError from sqlmodel import Field, Session, SQLModel, create_engine From e0d32fb27ae04d60e206cedef28d77439d4f29af Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 21:02:34 +0800 Subject: [PATCH 18/25] try fix py3.8/test_nullable.py --- sqlmodel/main.py | 5 ++++- tests/test_nullable.py | 12 ++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index dd41cc9..7d94c40 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -467,11 +467,14 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: meta = field.metadata[0] return AutoString(length=meta.max_length) - if get_origin(type_) is Annotated: + org_type = get_origin(type_) + if org_type is Annotated: type2 = get_args(type_)[0] if type2 is pydantic.AnyUrl: meta = get_args(type_)[1] return AutoString(length=meta.max_length) + elif org_type is pydantic.AnyUrl: + return AutoString(type_.__metadata__[0].max_length) # The 3rd is PydanticGeneralMetadata metadata = _get_field_metadata(field) diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 3117812..22e1e1b 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -57,7 +57,12 @@ def test_nullable_fields(clear_sqlmodel, caplog): str_default_str_nullable: str = Field(default="default", nullable=True) str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False) str_default_ellipsis_nullable: str = Field(default=..., nullable=True) - annotated_any_url: Optional[MoveSharedUrl] = Field(description="") + optional_url: Optional[MoveSharedUrl] = Field(default=None, description="") + url: MoveSharedUrl + annotated_url: Annotated[MoveSharedUrl, Field(description="")] + annotated_optional_url: Annotated[ + Optional[MoveSharedUrl], Field(description="") + ] = None engine = create_engine("sqlite://", echo=True) SQLModel.metadata.create_all(engine) @@ -88,7 +93,10 @@ def test_nullable_fields(clear_sqlmodel, caplog): assert "str_default_str_nullable VARCHAR," in create_table_log assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log assert "str_default_ellipsis_nullable VARCHAR," in create_table_log - assert "annotated_any_url VARCHAR(512) NOT NULL" in create_table_log + assert "optional_url VARCHAR(512), " in create_table_log + assert "url VARCHAR(512) NOT NULL," in create_table_log + assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log + assert "annotated_optional_url VARCHAR(512)," in create_table_log # Test for regression in https://github.com/tiangolo/sqlmodel/issues/420 From aa3325bb54df9a653f083de335c33669021d43b7 Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 21:30:10 +0800 Subject: [PATCH 19/25] ugly way to fix py3.8/Annotation --- sqlmodel/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7d94c40..cc5e8d9 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -448,7 +448,7 @@ def _is_optional_or_union(type_: Optional[type]) -> bool: def get_sqlalchemy_type(field: FieldInfo) -> Any: - type_: Optional[type] = field.annotation + type_: Optional[type] | _AnnotatedAlias = field.annotation # Resolve Optional/Union fields @@ -473,7 +473,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: if type2 is pydantic.AnyUrl: meta = get_args(type_)[1] return AutoString(length=meta.max_length) - elif org_type is pydantic.AnyUrl: + elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias: return AutoString(type_.__metadata__[0].max_length) # The 3rd is PydanticGeneralMetadata From e8247308fc94ae82faeeb469186d52c062f9655d Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 21:32:11 +0800 Subject: [PATCH 20/25] miss import _AnnotatedAlias --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index cc5e8d9..38e668e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -58,7 +58,7 @@ else: if sys.version_info >= (3, 9): from typing import Annotated else: - from typing_extensions import Annotated + from typing_extensions import Annotated, _AnnotatedAlias _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] From dcb406f0eda938cf4fa3c3b5b73c10b17e7945c9 Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 21:36:46 +0800 Subject: [PATCH 21/25] fix py3.9+ _AnnotatedAlias --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38e668e..4ffeb59 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -56,7 +56,7 @@ else: from typing_extensions import get_args, get_origin if sys.version_info >= (3, 9): - from typing import Annotated + from typing import Annotated, _AnnotatedAlias else: from typing_extensions import Annotated, _AnnotatedAlias From d13fb7435bf9bdcbca4671dcc5935f8c0e4adf45 Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 21:43:02 +0800 Subject: [PATCH 22/25] only use typing_extensions to import _AnnotatedAlias --- sqlmodel/main.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 4ffeb59..14dbffd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -55,10 +55,7 @@ if sys.version_info >= (3, 8): else: from typing_extensions import get_args, get_origin -if sys.version_info >= (3, 9): - from typing import Annotated, _AnnotatedAlias -else: - from typing_extensions import Annotated, _AnnotatedAlias +from typing_extensions import Annotated, _AnnotatedAlias _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] From c02b579f1b14f98fa0aaca0a62c6ab0c5ecfd543 Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 22:08:03 +0800 Subject: [PATCH 23/25] support AnyURL --- tests/test_nullable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 22e1e1b..6d3df54 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -57,6 +57,7 @@ def test_nullable_fields(clear_sqlmodel, caplog): str_default_str_nullable: str = Field(default="default", nullable=True) str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False) str_default_ellipsis_nullable: str = Field(default=..., nullable=True) + base_url : AnyUrl optional_url: Optional[MoveSharedUrl] = Field(default=None, description="") url: MoveSharedUrl annotated_url: Annotated[MoveSharedUrl, Field(description="")] @@ -93,6 +94,7 @@ def test_nullable_fields(clear_sqlmodel, caplog): assert "str_default_str_nullable VARCHAR," in create_table_log assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log assert "str_default_ellipsis_nullable VARCHAR," in create_table_log + assert "base_url VARCHAR NOT NULL," in create_table_log assert "optional_url VARCHAR(512), " in create_table_log assert "url VARCHAR(512) NOT NULL," in create_table_log assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log From cb6ccf4c079ddf4b4011acd71d60ebd557fbc099 Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 22:11:11 +0800 Subject: [PATCH 24/25] forgot black it --- tests/test_nullable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 6d3df54..041a889 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -57,7 +57,7 @@ def test_nullable_fields(clear_sqlmodel, caplog): str_default_str_nullable: str = Field(default="default", nullable=True) str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False) str_default_ellipsis_nullable: str = Field(default=..., nullable=True) - base_url : AnyUrl + base_url: AnyUrl optional_url: Optional[MoveSharedUrl] = Field(default=None, description="") url: MoveSharedUrl annotated_url: Annotated[MoveSharedUrl, Field(description="")] From 4213c978fc17b6f57337c43295adffd75eb554e4 Mon Sep 17 00:00:00 2001 From: honglei Date: Fri, 25 Aug 2023 22:12:34 +0800 Subject: [PATCH 25/25] support AnyURL --- sqlmodel/main.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 14dbffd..dbc05c4 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -461,8 +461,11 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: # UrlConstraints(max_length=512, # allowed_schemes=['smb', 'ftp', 'file']) ] if type_ is pydantic.AnyUrl: - meta = field.metadata[0] - return AutoString(length=meta.max_length) + if field.metadata: + meta = field.metadata[0] + return AutoString(length=meta.max_length) + else: + return AutoString org_type = get_origin(type_) if org_type is Annotated: