Merge pull request #1 from honglei/main

support str|None , mapped_column, AnyURL
This commit is contained in:
Santiago Martinez Balvanera 2023-08-14 22:48:50 +01:00 committed by GitHub
commit f67b414438
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import ipaddress
import sys
import types
import uuid
import weakref
from datetime import date, datetime, time, timedelta
@ -27,6 +28,7 @@ from typing import (
cast,
)
import pydantic
from pydantic import BaseModel
from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic._internal._model_construction import ModelMetaclass
@ -40,7 +42,9 @@ from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relati
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.orm.properties import MappedColumn
from sqlalchemy.sql import false, true
from sqlalchemy.sql.schema import DefaultClause, MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time
from .sql.sqltypes import GUID, AutoString
@ -51,6 +55,11 @@ if sys.version_info >= (3, 8):
else:
from typing_extensions import get_args, get_origin
if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
NoneType = type(None)
@ -158,7 +167,7 @@ def Field(
unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
@ -166,6 +175,32 @@ def Field(
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
if default is PydanticUndefined:
if isinstance(sa_column, types.FunctionType): # lambda
sa_column_ = sa_column()
else:
sa_column_ = sa_column
# server_default -> default
if isinstance(sa_column_, Column) and isinstance(
sa_column_.server_default, DefaultClause
):
default_value = sa_column_.server_default.arg
if issubclass(type(sa_column_.type), Integer) and isinstance(
default_value, str
):
default = int(default_value)
elif issubclass(type(sa_column_.type), Boolean):
if default_value is false():
default = False
elif default_value is true():
default = True
elif isinstance(default_value, str):
if default_value == "1":
default = True
elif default_value == "0":
default = False
field_info = FieldInfo(
default,
default_factory=default_factory,
@ -408,14 +443,33 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] = field.annotation
# Resolve Optional fields
if type_ is not None and get_origin(type_) is Union:
# Resolve Optional/Union fields
def is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union
if type_ is not None and is_optional_or_union(type_):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
"Cannot have a (non-optional) union as a SQL alchemy field"
)
type_ = bases[0]
# Resolve Annoted fields,
# like typing.Annotated[pydantic_core._pydantic_core.Url,
# UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
if get_origin(type_) is Annotated:
type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1]
return AutoString(length=meta.max_length)
# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
@ -468,6 +522,8 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)