mirror of
https://github.com/PaiGramTeam/sqlmodel.git
synced 2024-11-21 22:58:22 +00:00
get_column_from_field:sa_column>field attribute>field annotation
This commit is contained in:
parent
9e07c1c772
commit
5b49f778c3
@ -440,17 +440,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
|
||||
def _is_optional_or_union(type_: Optional[type]) -> bool:
|
||||
if sys.version_info >= (3, 10):
|
||||
return get_origin(type_) in (types.UnionType, Union)
|
||||
else:
|
||||
return get_origin(type_) is Union
|
||||
|
||||
|
||||
def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||
type_: Optional[type] = field.annotation
|
||||
|
||||
# Resolve Optional/Union fields
|
||||
def is_optional_or_union(type_: Optional[type]) -> bool:
|
||||
if sys.version_info >= (3, 10):
|
||||
return get_origin(type_) in (types.UnionType, Union)
|
||||
else:
|
||||
return get_origin(type_) is Union
|
||||
|
||||
if type_ is not None and is_optional_or_union(type_):
|
||||
if type_ is not None and _is_optional_or_union(type_):
|
||||
bases = get_args(type_)
|
||||
if len(bases) > 2:
|
||||
raise RuntimeError(
|
||||
@ -519,15 +521,27 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
||||
|
||||
|
||||
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
||||
"""
|
||||
sa_column > field attributes > annotation info
|
||||
"""
|
||||
sa_column = getattr(field, "sa_column", PydanticUndefined)
|
||||
col: Column | None = None
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
if isinstance(sa_column, MappedColumn):
|
||||
return sa_column.column
|
||||
if isinstance(sa_column, types.FunctionType):
|
||||
col = sa_column
|
||||
elif isinstance(sa_column, MappedColumn):
|
||||
col = sa_column.column
|
||||
elif isinstance(sa_column, types.FunctionType):
|
||||
col = sa_column()
|
||||
assert isinstance(col, Column)
|
||||
if isinstance(col, Column):
|
||||
# field attribute or field annotation -> Column.nullable
|
||||
if col.nullable is PydanticUndefined:
|
||||
col.nullable = _is_field_noneable(field)
|
||||
# field.primary_key -> Column.primary_key
|
||||
if col.primary_key is PydanticUndefined:
|
||||
primary_key = getattr(field, "primary_key", False)
|
||||
col.primary_key = primary_key
|
||||
return col
|
||||
|
||||
sa_type = get_sqlalchemy_type(field)
|
||||
primary_key = getattr(field, "primary_key", False)
|
||||
index = getattr(field, "index", PydanticUndefined)
|
||||
@ -661,14 +675,17 @@ def _is_field_noneable(field: FieldInfo) -> bool:
|
||||
return field.nullable
|
||||
if not field.is_required():
|
||||
default = getattr(field, "original_default", field.default)
|
||||
if default is PydanticUndefined:
|
||||
if default is None:
|
||||
return True
|
||||
elif default is not PydanticUndefined:
|
||||
return False
|
||||
if field.annotation is None or field.annotation is NoneType:
|
||||
return True
|
||||
if get_origin(field.annotation) is Union:
|
||||
if _is_optional_or_union(field.annotation):
|
||||
for base in get_args(field.annotation):
|
||||
if base is NoneType:
|
||||
return True
|
||||
|
||||
return False
|
||||
return False
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user