Formatting

This commit is contained in:
Anton De Meester 2023-07-31 13:40:23 +00:00
parent a7f8a27c20
commit 6955600120
4 changed files with 31 additions and 23 deletions

View File

@ -19,17 +19,19 @@ from typing import (
Set,
Tuple,
Type,
TypeVar,ForwardRef,
TypeVar,
Union,
cast,get_origin,get_args
cast,
get_args,
get_origin,
)
from pydantic import BaseModel, ValidationError
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic import BaseModel
from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic._internal._model_construction import ModelMetaclass
from pydantic._internal._repr import Representation
from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
@ -149,7 +151,9 @@ def Field(
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[Mapping[str, Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
] = PydanticUndefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
@ -268,7 +272,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
key: pydantic_kwargs.pop(key)
for key in pydantic_kwargs.keys() & allowed_config_kwargs
}
config_table = getattr(class_dict.get('Config', object()), 'table', False)
config_table = getattr(class_dict.get("Config", object()), "table", False)
# If we have a table, we need to have defaults for all fields
# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
if config_table is True:
@ -295,7 +299,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
config_table = get_config("table")
if config_table is True:
# If it was passed by kwargs, ensure it's also set in config
new_cls.model_config['table'] = config_table
new_cls.model_config["table"] = config_table
for k, v in new_cls.model_fields.items():
col = get_column_from_field(v)
setattr(new_cls, k, col)
@ -304,13 +308,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# This could be done by reading new_cls.model_config['table'] in FastAPI, but
# that's very specific about SQLModel, so let's have another config that
# other future tools based on Pydantic can use.
new_cls.model_config['read_from_attributes'] = True
new_cls.model_config["read_from_attributes"] = True
config_registry = get_config("registry")
if config_registry is not PydanticUndefined:
config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config
new_cls.model_config['registry'] = config_table
new_cls.model_config["registry"] = config_table
setattr(new_cls, "_sa_registry", config_registry)
setattr(new_cls, "metadata", config_registry.metadata)
setattr(new_cls, "__abstract__", True)
@ -383,7 +387,9 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
if type_ is not None and get_origin(type_) is Union:
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError("Cannot have a (non-optional) union as a SQL alchemy field")
raise RuntimeError(
"Cannot have a (non-optional) union as a SQL alchemy field"
)
type_ = bases[0]
# The 3rd is PydanticGeneralMetadata
@ -516,7 +522,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# Don't show SQLAlchemy private attributes
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()
@ -533,6 +538,7 @@ def _is_field_noneable(field: FieldInfo) -> bool:
return False
return False
def _get_field_metadata(field: FieldInfo) -> object:
for meta in field.metadata:
if isinstance(meta, PydanticGeneralMetadata):

View File

@ -1,5 +1,7 @@
from typing import Any, Optional
from pydantic import ConfigDict
from typing import Optional, Any
class SQLModelConfig(ConfigDict):
table: Optional[bool]

View File

@ -1,8 +1,8 @@
from typing import Optional
import pytest
from sqlmodel import Field, SQLModel
from pydantic import BaseModel
from sqlmodel import Field, SQLModel
def test_missing_sql_type():