From 347e0526562185d7c125e5de36db7797d1e6cdd4 Mon Sep 17 00:00:00 2001 From: honglei Date: Sat, 12 Aug 2023 19:52:59 +0800 Subject: [PATCH 1/8] support str|None , mapped_column, AnyURL --- sqlmodel/main.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 6e56dcd..b63de2f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -11,6 +11,7 @@ from pathlib import Path from typing import ( AbstractSet, Any, + Annotated, Callable, ClassVar, Dict, @@ -26,7 +27,9 @@ from typing import ( Union, cast, ) +import types +import pydantic from pydantic import BaseModel from pydantic._internal._fields import PydanticGeneralMetadata from pydantic._internal._model_construction import ModelMetaclass @@ -40,6 +43,7 @@ from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relati from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented +from sqlalchemy.orm.properties import MappedColumn from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time @@ -166,6 +170,18 @@ def Field( schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} + if default is PydanticUndefined: + if type(sa_column) is types.FunctionType: #lambda + sa_column_ = sa_column() + else: + sa_column_ = sa_column + + #assert type(sa_column) is Column: + if isinstance(sa_column_, Column) and sa_column_.server_default is not None: + default_value = sa_column_.server_default.arg + if issubclass( type(sa_column_.type), Integer): + default = int(default_value) + field_info = FieldInfo( default, default_factory=default_factory, @@ -408,14 +424,27 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: FieldInfo) -> Any: type_: Optional[type] = field.annotation - # Resolve Optional fields - if type_ is not None and get_origin(type_) is Union: + # Resolve Optional/Union fields + if type_ is not None and get_origin(type_) in (types.UnionType, Union): bases = get_args(type_) if len(bases) > 2: raise RuntimeError( "Cannot have a (non-optional) union as a SQL alchemy field" ) type_ = bases[0] + # Resolve Annoted fields, + # like typing.Annotated[pydantic_core._pydantic_core.Url, + # UrlConstraints(max_length=512, + # allowed_schemes=['smb', 'ftp', 'file']) ] + if type_ is pydantic.AnyUrl: + meta = field.metadata[0] + return AutoString(length=meta.max_length) + + if get_origin(type_) is Annotated: + type2 = type_.__args__[0] + if type2 is pydantic.AnyUrl: + meta = type_.__metadata__[0] + return AutoString(length=meta.max_length) # The 3rd is PydanticGeneralMetadata metadata = _get_field_metadata(field) @@ -466,7 +495,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: def get_column_from_field(field: FieldInfo) -> Column: # type: ignore sa_column = getattr(field, "sa_column", PydanticUndefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn) : return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) From 9da04076093ec435b0860a4c23ebf185e145cc0d Mon Sep 17 00:00:00 2001 From: honglei Date: Mon, 14 Aug 2023 21:03:41 +0800 Subject: [PATCH 2/8] support 3.9/use get_args --- sqlmodel/main.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index b63de2f..aa3def4 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -10,8 +10,7 @@ from enum import Enum from pathlib import Path from typing import ( AbstractSet, - Any, - Annotated, + Any, Callable, ClassVar, Dict, @@ -26,6 +25,7 @@ from typing import ( TypeVar, Union, cast, + get_args ) import types @@ -44,16 +44,16 @@ from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.orm.properties import MappedColumn -from sqlalchemy.sql.schema import MetaData +from sqlalchemy.sql.schema import MetaData, DefaultClause from sqlalchemy.sql.sqltypes import LargeBinary, Time from .sql.sqltypes import GUID, AutoString from .typing import SQLModelConfig if sys.version_info >= (3, 8): - from typing import get_args, get_origin + from typing import get_args, get_origin, Annotated else: - from typing_extensions import get_args, get_origin + from typing_extensions import get_args, get_origin, Annotated _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] @@ -162,7 +162,7 @@ def Field( unique: bool = False, nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined, - sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore + sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, sa_column_kwargs: Union[ Mapping[str, Any], PydanticUndefinedType @@ -176,8 +176,8 @@ def Field( else: sa_column_ = sa_column - #assert type(sa_column) is Column: - if isinstance(sa_column_, Column) and sa_column_.server_default is not None: + #server_default -> default + if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause): default_value = sa_column_.server_default.arg if issubclass( type(sa_column_.type), Integer): default = int(default_value) @@ -441,9 +441,9 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: return AutoString(length=meta.max_length) if get_origin(type_) is Annotated: - type2 = type_.__args__[0] + type2 = get_args(type_)[0] if type2 is pydantic.AnyUrl: - meta = type_.__metadata__[0] + meta = get_args(type_)[1] return AutoString(length=meta.max_length) # The 3rd is PydanticGeneralMetadata @@ -495,8 +495,10 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: def get_column_from_field(field: FieldInfo) -> Column: # type: ignore sa_column = getattr(field, "sa_column", PydanticUndefined) - if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn) : + if isinstance(sa_column, Column): return sa_column + if isinstance(sa_column, MappedColumn): + return sa_column.column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) index = getattr(field, "index", PydanticUndefined) From a22dac8113d4086adc215ac5e9571da69f6e8f14 Mon Sep 17 00:00:00 2001 From: honglei Date: Mon, 14 Aug 2023 23:02:32 +0800 Subject: [PATCH 3/8] avoid get_args directly --- sqlmodel/main.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index aa3def4..76f5aa6 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -24,8 +24,7 @@ from typing import ( Type, TypeVar, Union, - cast, - get_args + cast ) import types @@ -46,6 +45,8 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.orm.properties import MappedColumn from sqlalchemy.sql.schema import MetaData, DefaultClause from sqlalchemy.sql.sqltypes import LargeBinary, Time +from sqlalchemy.sql import false, true + from .sql.sqltypes import GUID, AutoString from .typing import SQLModelConfig @@ -179,8 +180,19 @@ def Field( #server_default -> default if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause): default_value = sa_column_.server_default.arg - if issubclass( type(sa_column_.type), Integer): + if issubclass( type(sa_column_.type), Integer) and isinstance(default_value, str): default = int(default_value) + elif issubclass( type(sa_column_.type), Boolean): + if default_value is false(): + default = False + elif default_value is true(): + default = True + elif isinstance(default_value, str): + if default_value == '1': + default =True + elif default_value == '0': + default = False + field_info = FieldInfo( default, From ce0064b286f30cead62d5781f3764854a5edae62 Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 00:28:01 +0800 Subject: [PATCH 4/8] python version for types.UnionType/Annotated --- sqlmodel/main.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 76f5aa6..0a7df79 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -52,9 +52,14 @@ from .sql.sqltypes import GUID, AutoString from .typing import SQLModelConfig if sys.version_info >= (3, 8): - from typing import get_args, get_origin, Annotated + from typing import get_args, get_origin else: - from typing_extensions import get_args, get_origin, Annotated + from typing_extensions import get_args, get_origin + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] @@ -437,7 +442,12 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: type_: Optional[type] = field.annotation # Resolve Optional/Union fields - if type_ is not None and get_origin(type_) in (types.UnionType, Union): + def is_optional_or_union(type_): + if sys.version_info >= (3, 10): + return get_origin(type_) in (types.UnionType, Union) + else: + return get_origin(type_) is Union + if type_ is not None and is_optional_or_union(type_): bases = get_args(type_) if len(bases) > 2: raise RuntimeError( From a0b84c574d46b4574be406fedbf29357fff2a35a Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 00:59:49 +0800 Subject: [PATCH 5/8] avoid compare types:FunctionType --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 0a7df79..0432bde 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -177,7 +177,7 @@ def Field( ) -> Any: current_schema_extra = schema_extra or {} if default is PydanticUndefined: - if type(sa_column) is types.FunctionType: #lambda + if isinstance(sa_column, types.FunctionType): #lambda sa_column_ = sa_column() else: sa_column_ = sa_column From fa8902c77825a7a18754bb27e9fe08bd66f0b27a Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 01:13:56 +0800 Subject: [PATCH 6/8] add type hints for func is_optional_or_union --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 0432bde..a88826a 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -442,7 +442,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: type_: Optional[type] = field.annotation # Resolve Optional/Union fields - def is_optional_or_union(type_): + def is_optional_or_union(type_: Optional[type]) -> bool: if sys.version_info >= (3, 10): return get_origin(type_) in (types.UnionType, Union) else: From 710e92b2859680e6c3acf1af917e90898b610582 Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 01:22:53 +0800 Subject: [PATCH 7/8] black it --- sqlmodel/main.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a88826a..1e9c6a9 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -10,7 +10,7 @@ from enum import Enum from pathlib import Path from typing import ( AbstractSet, - Any, + Any, Callable, ClassVar, Dict, @@ -24,7 +24,7 @@ from typing import ( Type, TypeVar, Union, - cast + cast, ) import types @@ -55,7 +55,7 @@ if sys.version_info >= (3, 8): from typing import get_args, get_origin else: from typing_extensions import get_args, get_origin - + if sys.version_info >= (3, 9): from typing import Annotated else: @@ -177,27 +177,30 @@ def Field( ) -> Any: current_schema_extra = schema_extra or {} if default is PydanticUndefined: - if isinstance(sa_column, types.FunctionType): #lambda + if isinstance(sa_column, types.FunctionType): # lambda sa_column_ = sa_column() else: sa_column_ = sa_column - - #server_default -> default - if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause): - default_value = sa_column_.server_default.arg - if issubclass( type(sa_column_.type), Integer) and isinstance(default_value, str): + + # server_default -> default + if isinstance(sa_column_, Column) and isinstance( + sa_column_.server_default, DefaultClause + ): + default_value = sa_column_.server_default.arg + if issubclass(type(sa_column_.type), Integer) and isinstance( + default_value, str + ): default = int(default_value) - elif issubclass( type(sa_column_.type), Boolean): + elif issubclass(type(sa_column_.type), Boolean): if default_value is false(): default = False elif default_value is true(): default = True elif isinstance(default_value, str): - if default_value == '1': - default =True - elif default_value == '0': + if default_value == "1": + default = True + elif default_value == "0": default = False - field_info = FieldInfo( default, @@ -447,6 +450,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: return get_origin(type_) in (types.UnionType, Union) else: return get_origin(type_) is Union + if type_ is not None and is_optional_or_union(type_): bases = get_args(type_) if len(bases) > 2: @@ -455,17 +459,17 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: ) type_ = bases[0] # Resolve Annoted fields, - # like typing.Annotated[pydantic_core._pydantic_core.Url, - # UrlConstraints(max_length=512, + # like typing.Annotated[pydantic_core._pydantic_core.Url, + # UrlConstraints(max_length=512, # allowed_schemes=['smb', 'ftp', 'file']) ] if type_ is pydantic.AnyUrl: - meta = field.metadata[0] + meta = field.metadata[0] return AutoString(length=meta.max_length) - - if get_origin(type_) is Annotated: + + if get_origin(type_) is Annotated: type2 = get_args(type_)[0] if type2 is pydantic.AnyUrl: - meta = get_args(type_)[1] + meta = get_args(type_)[1] return AutoString(length=meta.max_length) # The 3rd is PydanticGeneralMetadata From 46b130dfb762c40b0a75deb0eb80905b542f080a Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 01:30:56 +0800 Subject: [PATCH 8/8] fix isort error for `import types` --- sqlmodel/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 1e9c6a9..ab41b7b 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -2,6 +2,7 @@ from __future__ import annotations import ipaddress import sys +import types import uuid import weakref from datetime import date, datetime, time, timedelta @@ -26,7 +27,6 @@ from typing import ( Union, cast, ) -import types import pydantic from pydantic import BaseModel @@ -43,10 +43,9 @@ from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.orm.properties import MappedColumn -from sqlalchemy.sql.schema import MetaData, DefaultClause -from sqlalchemy.sql.sqltypes import LargeBinary, Time from sqlalchemy.sql import false, true - +from sqlalchemy.sql.schema import DefaultClause, MetaData +from sqlalchemy.sql.sqltypes import LargeBinary, Time from .sql.sqltypes import GUID, AutoString from .typing import SQLModelConfig