This commit is contained in:
honglei 2023-08-15 01:22:53 +08:00
parent fa8902c778
commit 710e92b285

View File

@ -10,7 +10,7 @@ from enum import Enum
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
AbstractSet, AbstractSet,
Any, Any,
Callable, Callable,
ClassVar, ClassVar,
Dict, Dict,
@ -24,7 +24,7 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
cast cast,
) )
import types import types
@ -55,7 +55,7 @@ if sys.version_info >= (3, 8):
from typing import get_args, get_origin from typing import get_args, get_origin
else: else:
from typing_extensions import get_args, get_origin from typing_extensions import get_args, get_origin
if sys.version_info >= (3, 9): if sys.version_info >= (3, 9):
from typing import Annotated from typing import Annotated
else: else:
@ -177,27 +177,30 @@ def Field(
) -> Any: ) -> Any:
current_schema_extra = schema_extra or {} current_schema_extra = schema_extra or {}
if default is PydanticUndefined: if default is PydanticUndefined:
if isinstance(sa_column, types.FunctionType): #lambda if isinstance(sa_column, types.FunctionType): # lambda
sa_column_ = sa_column() sa_column_ = sa_column()
else: else:
sa_column_ = sa_column sa_column_ = sa_column
#server_default -> default # server_default -> default
if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause): if isinstance(sa_column_, Column) and isinstance(
default_value = sa_column_.server_default.arg sa_column_.server_default, DefaultClause
if issubclass( type(sa_column_.type), Integer) and isinstance(default_value, str): ):
default_value = sa_column_.server_default.arg
if issubclass(type(sa_column_.type), Integer) and isinstance(
default_value, str
):
default = int(default_value) default = int(default_value)
elif issubclass( type(sa_column_.type), Boolean): elif issubclass(type(sa_column_.type), Boolean):
if default_value is false(): if default_value is false():
default = False default = False
elif default_value is true(): elif default_value is true():
default = True default = True
elif isinstance(default_value, str): elif isinstance(default_value, str):
if default_value == '1': if default_value == "1":
default =True default = True
elif default_value == '0': elif default_value == "0":
default = False default = False
field_info = FieldInfo( field_info = FieldInfo(
default, default,
@ -447,6 +450,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
return get_origin(type_) in (types.UnionType, Union) return get_origin(type_) in (types.UnionType, Union)
else: else:
return get_origin(type_) is Union return get_origin(type_) is Union
if type_ is not None and is_optional_or_union(type_): if type_ is not None and is_optional_or_union(type_):
bases = get_args(type_) bases = get_args(type_)
if len(bases) > 2: if len(bases) > 2:
@ -455,17 +459,17 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
) )
type_ = bases[0] type_ = bases[0]
# Resolve Annoted fields, # Resolve Annoted fields,
# like typing.Annotated[pydantic_core._pydantic_core.Url, # like typing.Annotated[pydantic_core._pydantic_core.Url,
# UrlConstraints(max_length=512, # UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ] # allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl: if type_ is pydantic.AnyUrl:
meta = field.metadata[0] meta = field.metadata[0]
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
if get_origin(type_) is Annotated: if get_origin(type_) is Annotated:
type2 = get_args(type_)[0] type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl: if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1] meta = get_args(type_)[1]
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
# The 3rd is PydanticGeneralMetadata # The 3rd is PydanticGeneralMetadata