From 710e92b2859680e6c3acf1af917e90898b610582 Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 15 Aug 2023 01:22:53 +0800 Subject: [PATCH] black it --- sqlmodel/main.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a88826a..1e9c6a9 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -10,7 +10,7 @@ from enum import Enum from pathlib import Path from typing import ( AbstractSet, - Any, + Any, Callable, ClassVar, Dict, @@ -24,7 +24,7 @@ from typing import ( Type, TypeVar, Union, - cast + cast, ) import types @@ -55,7 +55,7 @@ if sys.version_info >= (3, 8): from typing import get_args, get_origin else: from typing_extensions import get_args, get_origin - + if sys.version_info >= (3, 9): from typing import Annotated else: @@ -177,27 +177,30 @@ def Field( ) -> Any: current_schema_extra = schema_extra or {} if default is PydanticUndefined: - if isinstance(sa_column, types.FunctionType): #lambda + if isinstance(sa_column, types.FunctionType): # lambda sa_column_ = sa_column() else: sa_column_ = sa_column - - #server_default -> default - if isinstance(sa_column_, Column) and isinstance(sa_column_.server_default, DefaultClause): - default_value = sa_column_.server_default.arg - if issubclass( type(sa_column_.type), Integer) and isinstance(default_value, str): + + # server_default -> default + if isinstance(sa_column_, Column) and isinstance( + sa_column_.server_default, DefaultClause + ): + default_value = sa_column_.server_default.arg + if issubclass(type(sa_column_.type), Integer) and isinstance( + default_value, str + ): default = int(default_value) - elif issubclass( type(sa_column_.type), Boolean): + elif issubclass(type(sa_column_.type), Boolean): if default_value is false(): default = False elif default_value is true(): default = True elif isinstance(default_value, str): - if default_value == '1': - default =True - elif default_value == '0': + if default_value == "1": + default = True + elif default_value == "0": default = False - field_info = FieldInfo( default, @@ -447,6 +450,7 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: return get_origin(type_) in (types.UnionType, Union) else: return get_origin(type_) is Union + if type_ is not None and is_optional_or_union(type_): bases = get_args(type_) if len(bases) > 2: @@ -455,17 +459,17 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: ) type_ = bases[0] # Resolve Annoted fields, - # like typing.Annotated[pydantic_core._pydantic_core.Url, - # UrlConstraints(max_length=512, + # like typing.Annotated[pydantic_core._pydantic_core.Url, + # UrlConstraints(max_length=512, # allowed_schemes=['smb', 'ftp', 'file']) ] if type_ is pydantic.AnyUrl: - meta = field.metadata[0] + meta = field.metadata[0] return AutoString(length=meta.max_length) - - if get_origin(type_) is Annotated: + + if get_origin(type_) is Annotated: type2 = get_args(type_)[0] if type2 is pydantic.AnyUrl: - meta = get_args(type_)[1] + meta = get_args(type_)[1] return AutoString(length=meta.max_length) # The 3rd is PydanticGeneralMetadata