support 3.9/use get_args

This commit is contained in:
honglei 2023-08-14 21:03:41 +08:00
parent 347e052656
commit 9da0407609

View File

@ -10,8 +10,7 @@ from enum import Enum
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
AbstractSet, AbstractSet,
Any, Any,
Annotated,
Callable, Callable,
ClassVar, ClassVar,
Dict, Dict,
@ -26,6 +25,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
cast, cast,
get_args
) )
import types import types
@ -44,16 +44,16 @@ from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm.properties import MappedColumn from sqlalchemy.orm.properties import MappedColumn
from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import MetaData, DefaultClause
from sqlalchemy.sql.sqltypes import LargeBinary, Time from sqlalchemy.sql.sqltypes import LargeBinary, Time
from .sql.sqltypes import GUID, AutoString from .sql.sqltypes import GUID, AutoString
from .typing import SQLModelConfig from .typing import SQLModelConfig
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import get_args, get_origin from typing import get_args, get_origin, Annotated
else: else:
from typing_extensions import get_args, get_origin from typing_extensions import get_args, get_origin, Annotated
_T = TypeVar("_T") _T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any] NoArgAnyCallable = Callable[[], Any]
@ -162,7 +162,7 @@ def Field(
unique: bool = False, unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[ sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType Mapping[str, Any], PydanticUndefinedType
@ -176,8 +176,8 @@ def Field(
else: else:
sa_column_ = sa_column sa_column_ = sa_column
#assert type(sa_column) is Column: #server_default -> default
if isinstance(sa_column_, Column) and sa_column_.server_default is not None: if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause):
default_value = sa_column_.server_default.arg default_value = sa_column_.server_default.arg
if issubclass( type(sa_column_.type), Integer): if issubclass( type(sa_column_.type), Integer):
default = int(default_value) default = int(default_value)
@ -441,9 +441,9 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
if get_origin(type_) is Annotated: if get_origin(type_) is Annotated:
type2 = type_.__args__[0] type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl: if type2 is pydantic.AnyUrl:
meta = type_.__metadata__[0] meta = get_args(type_)[1]
return AutoString(length=meta.max_length) return AutoString(length=meta.max_length)
# The 3rd is PydanticGeneralMetadata # The 3rd is PydanticGeneralMetadata
@ -495,8 +495,10 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
sa_column = getattr(field, "sa_column", PydanticUndefined) sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn) : if isinstance(sa_column, Column):
return sa_column return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
sa_type = get_sqlalchemy_type(field) sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False) primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined) index = getattr(field, "index", PydanticUndefined)