support 3.9/use get_args

This commit is contained in:
honglei 2023-08-14 21:03:41 +08:00
parent 347e052656
commit 9da0407609

View File

@ -10,8 +10,7 @@ from enum import Enum
from pathlib import Path
from typing import (
AbstractSet,
Any,
Annotated,
Any,
Callable,
ClassVar,
Dict,
@ -26,6 +25,7 @@ from typing import (
TypeVar,
Union,
cast,
get_args
)
import types
@ -44,16 +44,16 @@ from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm.properties import MappedColumn
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import MetaData, DefaultClause
from sqlalchemy.sql.sqltypes import LargeBinary, Time
from .sql.sqltypes import GUID, AutoString
from .typing import SQLModelConfig
if sys.version_info >= (3, 8):
from typing import get_args, get_origin
from typing import get_args, get_origin, Annotated
else:
from typing_extensions import get_args, get_origin
from typing_extensions import get_args, get_origin, Annotated
_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
@ -162,7 +162,7 @@ def Field(
unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
@ -176,8 +176,8 @@ def Field(
else:
sa_column_ = sa_column
#assert type(sa_column) is Column:
if isinstance(sa_column_, Column) and sa_column_.server_default is not None:
#server_default -> default
if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause):
default_value = sa_column_.server_default.arg
if issubclass( type(sa_column_.type), Integer):
default = int(default_value)
@ -441,9 +441,9 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
return AutoString(length=meta.max_length)
if get_origin(type_) is Annotated:
type2 = type_.__args__[0]
type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl:
meta = type_.__metadata__[0]
meta = get_args(type_)[1]
return AutoString(length=meta.max_length)
# The 3rd is PydanticGeneralMetadata
@ -495,8 +495,10 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn) :
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)