diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 6e56dcd..ab41b7b 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -2,6 +2,7 @@ from __future__ import annotations import ipaddress import sys +import types import uuid import weakref from datetime import date, datetime, time, timedelta @@ -27,6 +28,7 @@ from typing import ( cast, ) +import pydantic from pydantic import BaseModel from pydantic._internal._fields import PydanticGeneralMetadata from pydantic._internal._model_construction import ModelMetaclass @@ -40,7 +42,9 @@ from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relati from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented -from sqlalchemy.sql.schema import MetaData +from sqlalchemy.orm.properties import MappedColumn +from sqlalchemy.sql import false, true +from sqlalchemy.sql.schema import DefaultClause, MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time from .sql.sqltypes import GUID, AutoString @@ -51,6 +55,11 @@ if sys.version_info >= (3, 8): else: from typing_extensions import get_args, get_origin +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] NoneType = type(None) @@ -158,7 +167,7 @@ def Field( unique: bool = False, nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined, - sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore + sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, sa_column_kwargs: Union[ Mapping[str, Any], PydanticUndefinedType @@ -166,6 +175,32 @@ def Field( schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} + if default is PydanticUndefined: + if isinstance(sa_column, types.FunctionType): # lambda + sa_column_ = sa_column() + else: + sa_column_ = sa_column + + # server_default -> default + if isinstance(sa_column_, Column) and isinstance( + sa_column_.server_default, DefaultClause + ): + default_value = sa_column_.server_default.arg + if issubclass(type(sa_column_.type), Integer) and isinstance( + default_value, str + ): + default = int(default_value) + elif issubclass(type(sa_column_.type), Boolean): + if default_value is false(): + default = False + elif default_value is true(): + default = True + elif isinstance(default_value, str): + if default_value == "1": + default = True + elif default_value == "0": + default = False + field_info = FieldInfo( default, default_factory=default_factory, @@ -408,14 +443,33 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: FieldInfo) -> Any: type_: Optional[type] = field.annotation - # Resolve Optional fields - if type_ is not None and get_origin(type_) is Union: + # Resolve Optional/Union fields + def is_optional_or_union(type_: Optional[type]) -> bool: + if sys.version_info >= (3, 10): + return get_origin(type_) in (types.UnionType, Union) + else: + return get_origin(type_) is Union + + if type_ is not None and is_optional_or_union(type_): bases = get_args(type_) if len(bases) > 2: raise RuntimeError( "Cannot have a (non-optional) union as a SQL alchemy field" ) type_ = bases[0] + # Resolve Annoted fields, + # like typing.Annotated[pydantic_core._pydantic_core.Url, + # UrlConstraints(max_length=512, + # allowed_schemes=['smb', 'ftp', 'file']) ] + if type_ is pydantic.AnyUrl: + meta = field.metadata[0] + return AutoString(length=meta.max_length) + + if get_origin(type_) is Annotated: + type2 = get_args(type_)[0] + if type2 is pydantic.AnyUrl: + meta = get_args(type_)[1] + return AutoString(length=meta.max_length) # The 3rd is PydanticGeneralMetadata metadata = _get_field_metadata(field) @@ -468,6 +522,8 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore sa_column = getattr(field, "sa_column", PydanticUndefined) if isinstance(sa_column, Column): return sa_column + if isinstance(sa_column, MappedColumn): + return sa_column.column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False) index = getattr(field, "index", PydanticUndefined)