This commit is contained in:
honglei 2023-08-15 01:22:53 +08:00
parent fa8902c778
commit 710e92b285

View File

@ -24,7 +24,7 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
cast cast,
) )
import types import types
@ -177,28 +177,31 @@ def Field(
) -> Any: ) -> Any:
current_schema_extra = schema_extra or {} current_schema_extra = schema_extra or {}
if default is PydanticUndefined: if default is PydanticUndefined:
if isinstance(sa_column, types.FunctionType): #lambda if isinstance(sa_column, types.FunctionType): # lambda
sa_column_ = sa_column() sa_column_ = sa_column()
else: else:
sa_column_ = sa_column sa_column_ = sa_column
#server_default -> default # server_default -> default
if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause): if isinstance(sa_column_, Column) and isinstance(
default_value = sa_column_.server_default.arg sa_column_.server_default, DefaultClause
if issubclass( type(sa_column_.type), Integer) and isinstance(default_value, str): ):
default_value = sa_column_.server_default.arg
if issubclass(type(sa_column_.type), Integer) and isinstance(
default_value, str
):
default = int(default_value) default = int(default_value)
elif issubclass( type(sa_column_.type), Boolean): elif issubclass(type(sa_column_.type), Boolean):
if default_value is false(): if default_value is false():
default = False default = False
elif default_value is true(): elif default_value is true():
default = True default = True
elif isinstance(default_value, str): elif isinstance(default_value, str):
if default_value == '1': if default_value == "1":
default =True default = True
elif default_value == '0': elif default_value == "0":
default = False default = False
field_info = FieldInfo( field_info = FieldInfo(
default, default,
default_factory=default_factory, default_factory=default_factory,
@ -447,6 +450,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
return get_origin(type_) in (types.UnionType, Union) return get_origin(type_) in (types.UnionType, Union)
else: else:
return get_origin(type_) is Union return get_origin(type_) is Union
if type_ is not None and is_optional_or_union(type_): if type_ is not None and is_optional_or_union(type_):
bases = get_args(type_) bases = get_args(type_)
if len(bases) > 2: if len(bases) > 2:
@ -459,13 +463,13 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
# UrlConstraints(max_length=512, # UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ] # allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl: if type_ is pydantic.AnyUrl:
meta = field.metadata[0] meta = field.metadata[0]
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
if get_origin(type_) is Annotated: if get_origin(type_) is Annotated:
type2 = get_args(type_)[0] type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl: if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1] meta = get_args(type_)[1]
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
# The 3rd is PydanticGeneralMetadata # The 3rd is PydanticGeneralMetadata