support str|None , mapped_column, AnyURL

This commit is contained in:
honglei 2023-08-12 19:52:59 +08:00
parent c99c1a9b1e
commit 347e052656

View File

@ -11,6 +11,7 @@ from pathlib import Path
from typing import (
AbstractSet,
Any,
Annotated,
Callable,
ClassVar,
Dict,
@ -26,7 +27,9 @@ from typing import (
Union,
cast,
)
import types
import pydantic
from pydantic import BaseModel
from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic._internal._model_construction import ModelMetaclass
@ -40,6 +43,7 @@ from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relati
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm.properties import MappedColumn
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time
@ -166,6 +170,18 @@ def Field(
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
if default is PydanticUndefined:
if type(sa_column) is types.FunctionType: #lambda
sa_column_ = sa_column()
else:
sa_column_ = sa_column
#assert type(sa_column) is Column:
if isinstance(sa_column_, Column) and sa_column_.server_default is not None:
default_value = sa_column_.server_default.arg
if issubclass( type(sa_column_.type), Integer):
default = int(default_value)
field_info = FieldInfo(
default,
default_factory=default_factory,
@ -408,14 +424,27 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] = field.annotation
# Resolve Optional fields
if type_ is not None and get_origin(type_) is Union:
# Resolve Optional/Union fields
if type_ is not None and get_origin(type_) in (types.UnionType, Union):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
"Cannot have a (non-optional) union as a SQL alchemy field"
)
type_ = bases[0]
# Resolve Annoted fields,
# like typing.Annotated[pydantic_core._pydantic_core.Url,
# UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
if get_origin(type_) is Annotated:
type2 = type_.__args__[0]
if type2 is pydantic.AnyUrl:
meta = type_.__metadata__[0]
return AutoString(length=meta.max_length)
# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
@ -466,7 +495,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column):
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn) :
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)