diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 6e56dcd..b63de2f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -11,6 +11,7 @@ from pathlib import Path from typing import ( AbstractSet, Any, + Annotated, Callable, ClassVar, Dict, @@ -26,7 +27,9 @@ from typing import ( Union, cast, ) +import types +import pydantic from pydantic import BaseModel from pydantic._internal._fields import PydanticGeneralMetadata from pydantic._internal._model_construction import ModelMetaclass @@ -40,6 +43,7 @@ from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relati from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented +from sqlalchemy.orm.properties import MappedColumn from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time @@ -166,6 +170,18 @@ def Field( schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} + if default is PydanticUndefined: + if type(sa_column) is types.FunctionType: #lambda + sa_column_ = sa_column() + else: + sa_column_ = sa_column + + #assert type(sa_column) is Column: + if isinstance(sa_column_, Column) and sa_column_.server_default is not None: + default_value = sa_column_.server_default.arg + if issubclass( type(sa_column_.type), Integer): + default = int(default_value) + field_info = FieldInfo( default, default_factory=default_factory, @@ -408,14 +424,27 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: FieldInfo) -> Any: type_: Optional[type] = field.annotation - # Resolve Optional fields - if type_ is not None and get_origin(type_) is Union: + # Resolve Optional/Union fields + if type_ is not None and get_origin(type_) in (types.UnionType, Union): bases = get_args(type_) if len(bases) > 2: raise RuntimeError( "Cannot have a (non-optional) union as a SQL alchemy field" ) type_ = bases[0] + # Resolve Annoted fields, + # like typing.Annotated[pydantic_core._pydantic_core.Url, + # UrlConstraints(max_length=512, + # allowed_schemes=['smb', 'ftp', 'file']) ] + if type_ is pydantic.AnyUrl: + meta = field.metadata[0] + return AutoString(length=meta.max_length) + + if get_origin(type_) is Annotated: + type2 = type_.__args__[0] + if type2 is pydantic.AnyUrl: + meta = type_.__metadata__[0] + return AutoString(length=meta.max_length) # The 3rd is PydanticGeneralMetadata metadata = _get_field_metadata(field) @@ -466,7 +495,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: def get_column_from_field(field: FieldInfo) -> Column: # type: ignore sa_column = getattr(field, "sa_column", PydanticUndefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn) : return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field, "primary_key", False)