mirror of
https://github.com/PaiGramTeam/sqlmodel.git
synced 2024-11-25 09:27:40 +00:00
support 3.9/use get_args
This commit is contained in:
parent
347e052656
commit
9da0407609
@ -10,8 +10,7 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
AbstractSet,
|
AbstractSet,
|
||||||
Any,
|
Any,
|
||||||
Annotated,
|
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
@ -26,6 +25,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
|
get_args
|
||||||
)
|
)
|
||||||
import types
|
import types
|
||||||
|
|
||||||
@ -44,16 +44,16 @@ from sqlalchemy.orm.attributes import set_attribute
|
|||||||
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
||||||
from sqlalchemy.orm.instrumentation import is_instrumented
|
from sqlalchemy.orm.instrumentation import is_instrumented
|
||||||
from sqlalchemy.orm.properties import MappedColumn
|
from sqlalchemy.orm.properties import MappedColumn
|
||||||
from sqlalchemy.sql.schema import MetaData
|
from sqlalchemy.sql.schema import MetaData, DefaultClause
|
||||||
from sqlalchemy.sql.sqltypes import LargeBinary, Time
|
from sqlalchemy.sql.sqltypes import LargeBinary, Time
|
||||||
|
|
||||||
from .sql.sqltypes import GUID, AutoString
|
from .sql.sqltypes import GUID, AutoString
|
||||||
from .typing import SQLModelConfig
|
from .typing import SQLModelConfig
|
||||||
|
|
||||||
if sys.version_info >= (3, 8):
|
if sys.version_info >= (3, 8):
|
||||||
from typing import get_args, get_origin
|
from typing import get_args, get_origin, Annotated
|
||||||
else:
|
else:
|
||||||
from typing_extensions import get_args, get_origin
|
from typing_extensions import get_args, get_origin, Annotated
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
NoArgAnyCallable = Callable[[], Any]
|
NoArgAnyCallable = Callable[[], Any]
|
||||||
@ -162,7 +162,7 @@ def Field(
|
|||||||
unique: bool = False,
|
unique: bool = False,
|
||||||
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||||
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
|
||||||
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
|
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
|
||||||
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
|
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
|
||||||
sa_column_kwargs: Union[
|
sa_column_kwargs: Union[
|
||||||
Mapping[str, Any], PydanticUndefinedType
|
Mapping[str, Any], PydanticUndefinedType
|
||||||
@ -176,8 +176,8 @@ def Field(
|
|||||||
else:
|
else:
|
||||||
sa_column_ = sa_column
|
sa_column_ = sa_column
|
||||||
|
|
||||||
#assert type(sa_column) is Column:
|
#server_default -> default
|
||||||
if isinstance(sa_column_, Column) and sa_column_.server_default is not None:
|
if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause):
|
||||||
default_value = sa_column_.server_default.arg
|
default_value = sa_column_.server_default.arg
|
||||||
if issubclass( type(sa_column_.type), Integer):
|
if issubclass( type(sa_column_.type), Integer):
|
||||||
default = int(default_value)
|
default = int(default_value)
|
||||||
@ -441,9 +441,9 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
|||||||
return AutoString(length=meta.max_length)
|
return AutoString(length=meta.max_length)
|
||||||
|
|
||||||
if get_origin(type_) is Annotated:
|
if get_origin(type_) is Annotated:
|
||||||
type2 = type_.__args__[0]
|
type2 = get_args(type_)[0]
|
||||||
if type2 is pydantic.AnyUrl:
|
if type2 is pydantic.AnyUrl:
|
||||||
meta = type_.__metadata__[0]
|
meta = get_args(type_)[1]
|
||||||
return AutoString(length=meta.max_length)
|
return AutoString(length=meta.max_length)
|
||||||
|
|
||||||
# The 3rd is PydanticGeneralMetadata
|
# The 3rd is PydanticGeneralMetadata
|
||||||
@ -495,8 +495,10 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
|
|||||||
|
|
||||||
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
|
||||||
sa_column = getattr(field, "sa_column", PydanticUndefined)
|
sa_column = getattr(field, "sa_column", PydanticUndefined)
|
||||||
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn) :
|
if isinstance(sa_column, Column):
|
||||||
return sa_column
|
return sa_column
|
||||||
|
if isinstance(sa_column, MappedColumn):
|
||||||
|
return sa_column.column
|
||||||
sa_type = get_sqlalchemy_type(field)
|
sa_type = get_sqlalchemy_type(field)
|
||||||
primary_key = getattr(field, "primary_key", False)
|
primary_key = getattr(field, "primary_key", False)
|
||||||
index = getattr(field, "index", PydanticUndefined)
|
index = getattr(field, "index", PydanticUndefined)
|
||||||
|
Loading…
Reference in New Issue
Block a user