mirror of
https://github.com/PaiGramTeam/sqlmodel.git
synced 2024-11-22 07:08:06 +00:00
Merge pull request #1 from honglei/main
support str|None , mapped_column, AnyURL
This commit is contained in:
commit
f67b414438
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import sys
|
||||
import types
|
||||
import uuid
|
||||
import weakref
|
||||
from datetime import date, datetime, time, timedelta
|
||||
@ -27,6 +28,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from pydantic._internal._fields import PydanticGeneralMetadata
|
||||
from pydantic._internal._model_construction import ModelMetaclass
|
||||
@ -40,7 +42,9 @@ from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relati
|
||||
from sqlalchemy.orm.attributes import set_attribute
|
||||
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
||||
from sqlalchemy.orm.instrumentation import is_instrumented
|
||||
from sqlalchemy.sql.schema import MetaData
|
||||
from sqlalchemy.orm.properties import MappedColumn
|
||||
from sqlalchemy.sql import false, true
|
||||
from sqlalchemy.sql.schema import DefaultClause, MetaData
|
||||
from sqlalchemy.sql.sqltypes import LargeBinary, Time
|
||||
|
||||
from .sql.sqltypes import GUID, AutoString
|
||||
@ -51,6 +55,11 @@ if sys.version_info >= (3, 8):
|
||||
else:
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
else:
|
||||
from typing_extensions import Annotated
|
||||
|
||||
_T = TypeVar("_T")
|
||||
NoArgAnyCallable = Callable[[], Any]
|
||||
NoneType = type(None)
|
||||
@ -158,7 +167,7 @@ def Field(
|
||||
unique: bool = False,
|
||||
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
|
||||
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
|
||||
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
|
||||
sa_column_kwargs: Union[
|
||||
Mapping[str, Any], PydanticUndefinedType
|
||||
@ -166,6 +175,32 @@ def Field(
|
||||
schema_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
current_schema_extra = schema_extra or {}
|
||||
if default is PydanticUndefined:
|
||||
if isinstance(sa_column, types.FunctionType): # lambda
|
||||
sa_column_ = sa_column()
|
||||
else:
|
||||
sa_column_ = sa_column
|
||||
|
||||
# server_default -> default
|
||||
if isinstance(sa_column_, Column) and isinstance(
|
||||
sa_column_.server_default, DefaultClause
|
||||
):
|
||||
default_value = sa_column_.server_default.arg
|
||||
if issubclass(type(sa_column_.type), Integer) and isinstance(
|
||||
default_value, str
|
||||
):
|
||||
default = int(default_value)
|
||||
elif issubclass(type(sa_column_.type), Boolean):
|
||||
if default_value is false():
|
||||
default = False
|
||||
elif default_value is true():
|
||||
default = True
|
||||
elif isinstance(default_value, str):
|
||||
if default_value == "1":
|
||||
default = True
|
||||
elif default_value == "0":
|
||||
default = False
|
||||
|
||||
field_info = FieldInfo(
|
||||
default,
|
||||
default_factory=default_factory,
|
||||
@ -408,14 +443,33 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||
type_: Optional[type] = field.annotation
|
||||
|
||||
# Resolve Optional fields
|
||||
if type_ is not None and get_origin(type_) is Union:
|
||||
# Resolve Optional/Union fields
|
||||
def is_optional_or_union(type_: Optional[type]) -> bool:
|
||||
if sys.version_info >= (3, 10):
|
||||
return get_origin(type_) in (types.UnionType, Union)
|
||||
else:
|
||||
return get_origin(type_) is Union
|
||||
|
||||
if type_ is not None and is_optional_or_union(type_):
|
||||
bases = get_args(type_)
|
||||
if len(bases) > 2:
|
||||
raise RuntimeError(
|
||||
"Cannot have a (non-optional) union as a SQL alchemy field"
|
||||
)
|
||||
type_ = bases[0]
|
||||
# Resolve Annoted fields,
|
||||
# like typing.Annotated[pydantic_core._pydantic_core.Url,
|
||||
# UrlConstraints(max_length=512,
|
||||
# allowed_schemes=['smb', 'ftp', 'file']) ]
|
||||
if type_ is pydantic.AnyUrl:
|
||||
meta = field.metadata[0]
|
||||
return AutoString(length=meta.max_length)
|
||||
|
||||
if get_origin(type_) is Annotated:
|
||||
type2 = get_args(type_)[0]
|
||||
if type2 is pydantic.AnyUrl:
|
||||
meta = get_args(type_)[1]
|
||||
return AutoString(length=meta.max_length)
|
||||
|
||||
# The 3rd is PydanticGeneralMetadata
|
||||
metadata = _get_field_metadata(field)
|
||||
@ -468,6 +522,8 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
||||
sa_column = getattr(field, "sa_column", PydanticUndefined)
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
if isinstance(sa_column, MappedColumn):
|
||||
return sa_column.column
|
||||
sa_type = get_sqlalchemy_type(field)
|
||||
primary_key = getattr(field, "primary_key", False)
|
||||
index = getattr(field, "index", PydanticUndefined)
|
||||
|
Loading…
Reference in New Issue
Block a user