get_column_from_field:sa_column>field attribute>field annotation

This commit is contained in:
honglei 2023-08-16 21:22:57 +08:00
parent 9e07c1c772
commit 5b49f778c3

View File

@ -440,17 +440,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def _is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union
def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] = field.annotation
# Resolve Optional/Union fields
def is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union
if type_ is not None and is_optional_or_union(type_):
if type_ is not None and _is_optional_or_union(type_):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
@ -519,15 +521,27 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"""
sa_column > field attributes > annotation info
"""
sa_column = getattr(field, "sa_column", PydanticUndefined)
col: Column | None = None
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
if isinstance(sa_column, types.FunctionType):
col = sa_column
elif isinstance(sa_column, MappedColumn):
col = sa_column.column
elif isinstance(sa_column, types.FunctionType):
col = sa_column()
assert isinstance(col, Column)
if isinstance(col, Column):
# field attribute or field annotation -> Column.nullable
if col.nullable is PydanticUndefined:
col.nullable = _is_field_noneable(field)
# field.primary_key -> Column.primary_key
if col.primary_key is PydanticUndefined:
primary_key = getattr(field, "primary_key", False)
col.primary_key = primary_key
return col
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)
@ -661,14 +675,17 @@ def _is_field_noneable(field: FieldInfo) -> bool:
return field.nullable
if not field.is_required():
default = getattr(field, "original_default", field.default)
if default is PydanticUndefined:
if default is None:
return True
elif default is not PydanticUndefined:
return False
if field.annotation is None or field.annotation is NoneType:
return True
if get_origin(field.annotation) is Union:
if _is_optional_or_union(field.annotation):
for base in get_args(field.annotation):
if base is NoneType:
return True
return False
return False