Formatting

This commit is contained in:
Anton De Meester 2023-07-31 13:40:23 +00:00
parent a7f8a27c20
commit 6955600120
4 changed files with 31 additions and 23 deletions

View File

@ -19,17 +19,19 @@ from typing import (
Set, Set,
Tuple, Tuple,
Type, Type,
TypeVar,ForwardRef, TypeVar,
Union, Union,
cast,get_origin,get_args cast,
get_args,
get_origin,
) )
from pydantic import BaseModel, ValidationError from pydantic import BaseModel
from pydantic_core import PydanticUndefined, PydanticUndefinedType from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic._internal._model_construction import ModelMetaclass from pydantic._internal._model_construction import ModelMetaclass
from pydantic._internal._repr import Representation from pydantic._internal._repr import Representation
from pydantic._internal._fields import PydanticGeneralMetadata from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from sqlalchemy import Boolean, Column, Date, DateTime from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
@ -149,7 +151,9 @@ def Field(
index: Union[bool, PydanticUndefinedType] = PydanticUndefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[Mapping[str, Any], PydanticUndefinedType] = PydanticUndefined, sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
] = PydanticUndefined,
schema_extra: Optional[Dict[str, Any]] = None, schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ) -> Any:
current_schema_extra = schema_extra or {} current_schema_extra = schema_extra or {}
@ -231,7 +235,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
class_dict: Dict[str, Any], class_dict: Dict[str, Any],
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
relationships: Dict[str, RelationshipInfo] = {} relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {} dict_for_pydantic = {}
original_annotations = class_dict.get("__annotations__", {}) original_annotations = class_dict.get("__annotations__", {})
@ -268,21 +272,21 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
key: pydantic_kwargs.pop(key) key: pydantic_kwargs.pop(key)
for key in pydantic_kwargs.keys() & allowed_config_kwargs for key in pydantic_kwargs.keys() & allowed_config_kwargs
} }
config_table = getattr(class_dict.get('Config', object()), 'table', False) config_table = getattr(class_dict.get("Config", object()), "table", False)
# If we have a table, we need to have defaults for all fields # If we have a table, we need to have defaults for all fields
# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
if config_table is True: if config_table is True:
for key in original_annotations.keys(): for key in original_annotations.keys():
if dict_used.get(key, PydanticUndefined) is PydanticUndefined: if dict_used.get(key, PydanticUndefined) is PydanticUndefined:
dict_used[key] = None dict_used[key] = None
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
new_cls.__annotations__ = { new_cls.__annotations__ = {
**relationship_annotations, **relationship_annotations,
**pydantic_annotations, **pydantic_annotations,
**new_cls.__annotations__, **new_cls.__annotations__,
} }
def get_config(name: str) -> Any: def get_config(name: str) -> Any:
config_class_value = new_cls.model_config.get(name, PydanticUndefined) config_class_value = new_cls.model_config.get(name, PydanticUndefined)
if config_class_value is not PydanticUndefined: if config_class_value is not PydanticUndefined:
@ -295,7 +299,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
config_table = get_config("table") config_table = get_config("table")
if config_table is True: if config_table is True:
# If it was passed by kwargs, ensure it's also set in config # If it was passed by kwargs, ensure it's also set in config
new_cls.model_config['table'] = config_table new_cls.model_config["table"] = config_table
for k, v in new_cls.model_fields.items(): for k, v in new_cls.model_fields.items():
col = get_column_from_field(v) col = get_column_from_field(v)
setattr(new_cls, k, col) setattr(new_cls, k, col)
@ -304,13 +308,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# This could be done by reading new_cls.model_config['table'] in FastAPI, but # This could be done by reading new_cls.model_config['table'] in FastAPI, but
# that's very specific about SQLModel, so let's have another config that # that's very specific about SQLModel, so let's have another config that
# other future tools based on Pydantic can use. # other future tools based on Pydantic can use.
new_cls.model_config['read_from_attributes'] = True new_cls.model_config["read_from_attributes"] = True
config_registry = get_config("registry") config_registry = get_config("registry")
if config_registry is not PydanticUndefined: if config_registry is not PydanticUndefined:
config_registry = cast(registry, config_registry) config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config # If it was passed by kwargs, ensure it's also set in config
new_cls.model_config['registry'] = config_table new_cls.model_config["registry"] = config_table
setattr(new_cls, "_sa_registry", config_registry) setattr(new_cls, "_sa_registry", config_registry)
setattr(new_cls, "metadata", config_registry.metadata) setattr(new_cls, "metadata", config_registry.metadata)
setattr(new_cls, "__abstract__", True) setattr(new_cls, "__abstract__", True)
@ -367,7 +371,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_args.extend(rel_info.sa_relationship_args) rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs: if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs) rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship( rel_value: RelationshipProperty = relationship(
relationship_to, *rel_args, **rel_kwargs relationship_to, *rel_args, **rel_kwargs
) )
dict_used[rel_name] = rel_value dict_used[rel_name] = rel_value
@ -383,8 +387,10 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
if type_ is not None and get_origin(type_) is Union: if type_ is not None and get_origin(type_) is Union:
bases = get_args(type_) bases = get_args(type_)
if len(bases) > 2: if len(bases) > 2:
raise RuntimeError("Cannot have a (non-optional) union as a SQL alchemy field") raise RuntimeError(
type_ = bases[0] "Cannot have a (non-optional) union as a SQL alchemy field"
)
type_ = bases[0]
# The 3rd is PydanticGeneralMetadata # The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field) metadata = _get_field_metadata(field)
@ -516,7 +522,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# Don't show SQLAlchemy private attributes # Don't show SQLAlchemy private attributes
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
@declared_attr # type: ignore @declared_attr # type: ignore
def __tablename__(cls) -> str: def __tablename__(cls) -> str:
return cls.__name__.lower() return cls.__name__.lower()
@ -533,8 +538,9 @@ def _is_field_noneable(field: FieldInfo) -> bool:
return False return False
return False return False
def _get_field_metadata(field: FieldInfo) -> object: def _get_field_metadata(field: FieldInfo) -> object:
for meta in field.metadata: for meta in field.metadata:
if isinstance(meta, PydanticGeneralMetadata): if isinstance(meta, PydanticGeneralMetadata):
return meta return meta
return object() return object()

View File

@ -1,7 +1,9 @@
from typing import Any, Optional
from pydantic import ConfigDict from pydantic import ConfigDict
from typing import Optional, Any
class SQLModelConfig(ConfigDict): class SQLModelConfig(ConfigDict):
table: Optional[bool] table: Optional[bool]
read_from_attributes: Optional[bool] read_from_attributes: Optional[bool]
registry: Optional[Any] registry: Optional[Any]

View File

@ -8,7 +8,7 @@ from sqlmodel import Field, SQLModel
def test_allow_instantiation_without_arguments(clear_sqlmodel): def test_allow_instantiation_without_arguments(clear_sqlmodel):
class Item(SQLModel): class Item(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
name: str name: str
description: Optional[str] = None description: Optional[str] = None
class Config: class Config:

View File

@ -1,8 +1,8 @@
from typing import Optional from typing import Optional
import pytest import pytest
from sqlmodel import Field, SQLModel
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Field, SQLModel
def test_missing_sql_type(): def test_missing_sql_type():