Upgrade to Pydantic 2

Change imports
Undefined => PydanticUndefined
Update SQLModelMetaclass and SQLModel __init__ and __new__ functions
Update SQL Alchemy type inference
This commit is contained in:
Anton De Meester 2023-07-31 13:38:26 +00:00
parent b1848af842
commit a7f8a27c20
17 changed files with 175 additions and 283 deletions

View File

@ -174,13 +174,13 @@ Now we use the type annotation `HeroCreate` for the request JSON data in the `he
# Code below omitted 👇 # Code below omitted 👇
``` ```
Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.from_orm()`. Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.model_validate()`.
The method `.from_orm()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. The method `.model_validate()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`.
The alternative is `Hero.parse_obj()` that reads data from a dictionary. The alternative is `Hero.parse_obj()` that reads data from a dictionary.
But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.from_orm()` to read those attributes. But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes.
With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request.

View File

@ -54,7 +54,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -50,7 +50,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate): def create_hero(hero: HeroCreate):
with Session(engine) as session: with Session(engine) as session:
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -44,7 +44,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate): def create_hero(hero: HeroCreate):
with Session(engine) as session: with Session(engine) as session:
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -46,7 +46,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate): def create_hero(hero: HeroCreate):
with Session(engine) as session: with Session(engine) as session:
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -44,7 +44,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate): def create_hero(hero: HeroCreate):
with Session(engine) as session: with Session(engine) as session:
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -44,7 +44,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate): def create_hero(hero: HeroCreate):
with Session(engine) as session: with Session(engine) as session:
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -92,7 +92,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)
@ -147,7 +147,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int):
@app.post("/teams/", response_model=TeamRead) @app.post("/teams/", response_model=TeamRead)
def create_team(*, session: Session = Depends(get_session), team: TeamCreate): def create_team(*, session: Session = Depends(get_session), team: TeamCreate):
db_team = Team.from_orm(team) db_team = Team.model_validate(team)
session.add(db_team) session.add(db_team)
session.commit() session.commit()
session.refresh(db_team) session.refresh(db_team)

View File

@ -54,7 +54,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -83,7 +83,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)
@ -138,7 +138,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int):
@app.post("/teams/", response_model=TeamRead) @app.post("/teams/", response_model=TeamRead)
def create_team(*, session: Session = Depends(get_session), team: TeamCreate): def create_team(*, session: Session = Depends(get_session), team: TeamCreate):
db_team = Team.from_orm(team) db_team = Team.model_validate(team)
session.add(db_team) session.add(db_team)
session.commit() session.commit()
session.refresh(db_team) session.refresh(db_team)

View File

@ -50,7 +50,7 @@ def on_startup():
@app.post("/heroes/", response_model=HeroRead) @app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate): def create_hero(hero: HeroCreate):
with Session(engine) as session: with Session(engine) as session:
db_hero = Hero.from_orm(hero) db_hero = Hero.model_validate(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
session.refresh(db_hero) session.refresh(db_hero)

View File

@ -33,7 +33,7 @@ classifiers = [
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.7" python = "^3.7"
SQLAlchemy = ">=2.0.0,<=2.0.11" SQLAlchemy = ">=2.0.0,<=2.0.11"
pydantic = "^1.8.2" pydantic = "^2.1.1"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^7.0.1" pytest = "^7.0.1"
@ -46,12 +46,13 @@ pillow = "^9.3.0"
cairosvg = "^2.5.2" cairosvg = "^2.5.2"
mdx-include = "^1.4.1" mdx-include = "^1.4.1"
coverage = {extras = ["toml"], version = "^6.2"} coverage = {extras = ["toml"], version = "^6.2"}
fastapi = "^0.68.1" fastapi = "^0.100.0"
requests = "^2.26.0" requests = "^2.26.0"
autoflake = "^1.4" autoflake = "^1.4"
isort = "^5.9.3" isort = "^5.9.3"
async_generator = {version = "*", python = "~3.7"} async_generator = {version = "*", python = "~3.7"}
async-exit-stack = {version = "*", python = "~3.7"} async-exit-stack = {version = "*", python = "~3.7"}
httpx = "^0.24.1"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]

View File

@ -19,19 +19,17 @@ from typing import (
Set, Set,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,ForwardRef,
Union, Union,
cast, cast,get_origin,get_args
) )
from pydantic import BaseConfig, BaseModel from pydantic import BaseModel, ValidationError
from pydantic.errors import ConfigError, DictError from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic.fields import SHAPE_SINGLETON
from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic._internal._model_construction import ModelMetaclass
from pydantic.main import ModelMetaclass, validate_model from pydantic._internal._repr import Representation
from pydantic.typing import NoArgAnyCallable, resolve_annotations from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic.utils import ROOT_KEY, Representation
from sqlalchemy import Boolean, Column, Date, DateTime from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
@ -43,8 +41,10 @@ from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time from sqlalchemy.sql.sqltypes import LargeBinary, Time
from .sql.sqltypes import GUID, AutoString from .sql.sqltypes import GUID, AutoString
from .typing import SQLModelConfig
_T = TypeVar("_T") _T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
def __dataclass_transform__( def __dataclass_transform__(
@ -58,22 +58,22 @@ def __dataclass_transform__(
class FieldInfo(PydanticFieldInfo): class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False) primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", Undefined) nullable = kwargs.pop("nullable", PydanticUndefined)
foreign_key = kwargs.pop("foreign_key", Undefined) foreign_key = kwargs.pop("foreign_key", PydanticUndefined)
unique = kwargs.pop("unique", False) unique = kwargs.pop("unique", False)
index = kwargs.pop("index", Undefined) index = kwargs.pop("index", PydanticUndefined)
sa_column = kwargs.pop("sa_column", Undefined) sa_column = kwargs.pop("sa_column", PydanticUndefined)
sa_column_args = kwargs.pop("sa_column_args", Undefined) sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined)
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined)
if sa_column is not Undefined: if sa_column is not PydanticUndefined:
if sa_column_args is not Undefined: if sa_column_args is not PydanticUndefined:
raise RuntimeError( raise RuntimeError(
"Passing sa_column_args is not supported when " "Passing sa_column_args is not supported when "
"also passing a sa_column" "also passing a sa_column"
) )
if sa_column_kwargs is not Undefined: if sa_column_kwargs is not PydanticUndefined:
raise RuntimeError( raise RuntimeError(
"Passing sa_column_kwargs is not supported when " "Passing sa_column_kwargs is not supported when "
"also passing a sa_column" "also passing a sa_column"
@ -118,7 +118,7 @@ class RelationshipInfo(Representation):
def Field( def Field(
default: Any = Undefined, default: Any = PydanticUndefined,
*, *,
default_factory: Optional[NoArgAnyCallable] = None, default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None, alias: Optional[str] = None,
@ -145,11 +145,11 @@ def Field(
primary_key: bool = False, primary_key: bool = False,
foreign_key: Optional[Any] = None, foreign_key: Optional[Any] = None,
unique: bool = False, unique: bool = False,
nullable: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, UndefinedType] = Undefined, index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], PydanticUndefinedType] = PydanticUndefined,
schema_extra: Optional[Dict[str, Any]] = None, schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ) -> Any:
current_schema_extra = schema_extra or {} current_schema_extra = schema_extra or {}
@ -183,7 +183,6 @@ def Field(
sa_column_kwargs=sa_column_kwargs, sa_column_kwargs=sa_column_kwargs,
**current_schema_extra, **current_schema_extra,
) )
field_info._validate()
return field_info return field_info
@ -191,7 +190,7 @@ def Relationship(
*, *,
back_populates: Optional[str] = None, back_populates: Optional[str] = None,
link_model: Optional[Any] = None, link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any: ) -> Any:
@ -208,18 +207,18 @@ def Relationship(
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo] __sqlmodel_relationships__: Dict[str, RelationshipInfo]
__config__: Type[BaseConfig] model_config: Type[SQLModelConfig]
__fields__: Dict[str, ModelField] model_fields: Dict[str, FieldInfo]
# Replicate SQLAlchemy # Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None: def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False): if cls.model_config.get("table", False):
DeclarativeMeta.__setattr__(cls, name, value) DeclarativeMeta.__setattr__(cls, name, value)
else: else:
super().__setattr__(name, value) super().__setattr__(name, value)
def __delattr__(cls, name: str) -> None: def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False): if cls.model_config.get("table", False):
DeclarativeMeta.__delattr__(cls, name) DeclarativeMeta.__delattr__(cls, name)
else: else:
super().__delattr__(name) super().__delattr__(name)
@ -232,11 +231,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
class_dict: Dict[str, Any], class_dict: Dict[str, Any],
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
relationships: Dict[str, RelationshipInfo] = {} relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {} dict_for_pydantic = {}
original_annotations = resolve_annotations( original_annotations = class_dict.get("__annotations__", {})
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
)
pydantic_annotations = {} pydantic_annotations = {}
relationship_annotations = {} relationship_annotations = {}
for k, v in class_dict.items(): for k, v in class_dict.items():
@ -260,7 +258,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# superclass causing an error # superclass causing an error
allowed_config_kwargs: Set[str] = { allowed_config_kwargs: Set[str] = {
key key
for key in dir(BaseConfig) for key in dir(SQLModelConfig)
if not ( if not (
key.startswith("__") and key.endswith("__") key.startswith("__") and key.endswith("__")
) # skip dunder methods and attributes ) # skip dunder methods and attributes
@ -270,41 +268,49 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
key: pydantic_kwargs.pop(key) key: pydantic_kwargs.pop(key)
for key in pydantic_kwargs.keys() & allowed_config_kwargs for key in pydantic_kwargs.keys() & allowed_config_kwargs
} }
config_table = getattr(class_dict.get('Config', object()), 'table', False)
# If we have a table, we need to have defaults for all fields
# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
if config_table is True:
for key in original_annotations.keys():
if dict_used.get(key, PydanticUndefined) is PydanticUndefined:
dict_used[key] = None
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
new_cls.__annotations__ = { new_cls.__annotations__ = {
**relationship_annotations, **relationship_annotations,
**pydantic_annotations, **pydantic_annotations,
**new_cls.__annotations__, **new_cls.__annotations__,
} }
def get_config(name: str) -> Any: def get_config(name: str) -> Any:
config_class_value = getattr(new_cls.__config__, name, Undefined) config_class_value = new_cls.model_config.get(name, PydanticUndefined)
if config_class_value is not Undefined: if config_class_value is not PydanticUndefined:
return config_class_value return config_class_value
kwarg_value = kwargs.get(name, Undefined) kwarg_value = kwargs.get(name, PydanticUndefined)
if kwarg_value is not Undefined: if kwarg_value is not PydanticUndefined:
return kwarg_value return kwarg_value
return Undefined return PydanticUndefined
config_table = get_config("table") config_table = get_config("table")
if config_table is True: if config_table is True:
# If it was passed by kwargs, ensure it's also set in config # If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.table = config_table new_cls.model_config['table'] = config_table
for k, v in new_cls.__fields__.items(): for k, v in new_cls.model_fields.items():
col = get_column_from_field(v) col = get_column_from_field(v)
setattr(new_cls, k, col) setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field # Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict. # in orm_mode instead of preemptively converting it to a dict.
# This could be done by reading new_cls.__config__.table in FastAPI, but # This could be done by reading new_cls.model_config['table'] in FastAPI, but
# that's very specific about SQLModel, so let's have another config that # that's very specific about SQLModel, so let's have another config that
# other future tools based on Pydantic can use. # other future tools based on Pydantic can use.
new_cls.__config__.read_with_orm_mode = True new_cls.model_config['read_from_attributes'] = True
config_registry = get_config("registry") config_registry = get_config("registry")
if config_registry is not Undefined: if config_registry is not PydanticUndefined:
config_registry = cast(registry, config_registry) config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config # If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.registry = config_table new_cls.model_config['registry'] = config_table
setattr(new_cls, "_sa_registry", config_registry) setattr(new_cls, "_sa_registry", config_registry)
setattr(new_cls, "metadata", config_registry.metadata) setattr(new_cls, "metadata", config_registry.metadata)
setattr(new_cls, "__abstract__", True) setattr(new_cls, "__abstract__", True)
@ -320,13 +326,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# triggers an error # triggers an error
base_is_table = False base_is_table = False
for base in bases: for base in bases:
config = getattr(base, "__config__") config = getattr(base, "model_config")
if config and getattr(config, "table", False): if config and getattr(config, "table", False):
base_is_table = True base_is_table = True
break break
if getattr(cls.__config__, "table", False) and not base_is_table: if cls.model_config.get("table", False) and not base_is_table:
dict_used = dict_.copy() dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items(): for field_name, field_value in cls.model_fields.items():
dict_used[field_name] = get_column_from_field(field_value) dict_used[field_name] = get_column_from_field(field_value)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship: if rel_info.sa_relationship:
@ -335,16 +341,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
dict_used[rel_name] = rel_info.sa_relationship dict_used[rel_name] = rel_info.sa_relationship
continue continue
ann = cls.__annotations__[rel_name] ann = cls.__annotations__[rel_name]
temp_field = ModelField.infer( relationship_to = get_origin(ann)
name=rel_name, # If Union (Optional), get the real field
value=rel_info, if relationship_to is Union:
annotation=ann, relationship_to = get_args(ann)[0]
class_validators=None, # If a list, then also get the real field
config=BaseConfig, elif relationship_to is list:
) relationship_to = get_args(ann)[0]
relationship_to = temp_field.type_ if isinstance(relationship_to, ForwardRef):
if isinstance(temp_field.type_, ForwardRef): relationship_to = relationship_to.__forward_arg__
relationship_to = temp_field.type_.__forward_arg__
rel_kwargs: Dict[str, Any] = {} rel_kwargs: Dict[str, Any] = {}
if rel_info.back_populates: if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates rel_kwargs["back_populates"] = rel_info.back_populates
@ -362,7 +367,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_args.extend(rel_info.sa_relationship_args) rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs: if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs) rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship( # type: ignore rel_value: RelationshipProperty = relationship(
relationship_to, *rel_args, **rel_kwargs relationship_to, *rel_args, **rel_kwargs
) )
dict_used[rel_name] = rel_value dict_used[rel_name] = rel_value
@ -372,68 +377,78 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def get_sqlalchemy_type(field: ModelField) -> Any: def get_sqlalchemy_type(field: FieldInfo) -> Any:
if issubclass(field.type_, str): type_ = field.annotation
if field.field_info.max_length: # Resolve Optional fields
return AutoString(length=field.field_info.max_length) if type_ is not None and get_origin(type_) is Union:
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError("Cannot have a (non-optional) union as a SQL alchemy field")
type_ = bases[0]
# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
if issubclass(type_, str):
if getattr(metadata, "max_length", None):
return AutoString(length=metadata.max_length)
return AutoString return AutoString
if issubclass(field.type_, float): if issubclass(type_, float):
return Float return Float
if issubclass(field.type_, bool): if issubclass(type_, bool):
return Boolean return Boolean
if issubclass(field.type_, int): if issubclass(type_, int):
return Integer return Integer
if issubclass(field.type_, datetime): if issubclass(type_, datetime):
return DateTime return DateTime
if issubclass(field.type_, date): if issubclass(type_, date):
return Date return Date
if issubclass(field.type_, timedelta): if issubclass(type_, timedelta):
return Interval return Interval
if issubclass(field.type_, time): if issubclass(type_, time):
return Time return Time
if issubclass(field.type_, Enum): if issubclass(type_, Enum):
return sa_Enum(field.type_) return sa_Enum(type_)
if issubclass(field.type_, bytes): if issubclass(type_, bytes):
return LargeBinary return LargeBinary
if issubclass(field.type_, Decimal): if issubclass(type_, Decimal):
return Numeric( return Numeric(
precision=getattr(field.type_, "max_digits", None), precision=getattr(metadata, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None), scale=getattr(metadata, "decimal_places", None),
) )
if issubclass(field.type_, ipaddress.IPv4Address): if issubclass(type_, ipaddress.IPv4Address):
return AutoString return AutoString
if issubclass(field.type_, ipaddress.IPv4Network): if issubclass(type_, ipaddress.IPv4Network):
return AutoString return AutoString
if issubclass(field.type_, ipaddress.IPv6Address): if issubclass(type_, ipaddress.IPv6Address):
return AutoString return AutoString
if issubclass(field.type_, ipaddress.IPv6Network): if issubclass(type_, ipaddress.IPv6Network):
return AutoString return AutoString
if issubclass(field.type_, Path): if issubclass(type_, Path):
return AutoString return AutoString
if issubclass(field.type_, uuid.UUID): if issubclass(type_, uuid.UUID):
return GUID return GUID
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") raise ValueError(f"The field {field.title} has no matching SQLAlchemy type")
def get_column_from_field(field: ModelField) -> Column: # type: ignore def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined) sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column): if isinstance(sa_column, Column):
return sa_column return sa_column
sa_type = get_sqlalchemy_type(field) sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field.field_info, "primary_key", False) primary_key = getattr(field, "primary_key", False)
index = getattr(field.field_info, "index", Undefined) index = getattr(field, "index", PydanticUndefined)
if index is Undefined: if index is PydanticUndefined:
index = False index = False
nullable = not primary_key and _is_field_noneable(field) nullable = not primary_key and _is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly # Override derived nullability if the nullable property is set explicitly
# on the field # on the field
if hasattr(field.field_info, "nullable"): if hasattr(field, "nullable"):
field_nullable = getattr(field.field_info, "nullable") field_nullable = getattr(field, "nullable")
if field_nullable != Undefined: if field_nullable != PydanticUndefined:
nullable = field_nullable nullable = field_nullable
args = [] args = []
foreign_key = getattr(field.field_info, "foreign_key", None) foreign_key = getattr(field, "foreign_key", None)
unique = getattr(field.field_info, "unique", False) unique = getattr(field, "unique", False)
if foreign_key: if foreign_key:
args.append(ForeignKey(foreign_key)) args.append(ForeignKey(foreign_key))
kwargs = { kwargs = {
@ -442,18 +457,18 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
"index": index, "index": index,
"unique": unique, "unique": unique,
} }
sa_default = Undefined sa_default = PydanticUndefined
if field.field_info.default_factory: if field.default_factory:
sa_default = field.field_info.default_factory sa_default = field.default_factory
elif field.field_info.default is not Undefined: elif field.default is not PydanticUndefined:
sa_default = field.field_info.default sa_default = field.default
if sa_default is not Undefined: if sa_default is not PydanticUndefined:
kwargs["default"] = sa_default kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) sa_column_args = getattr(field, "sa_column_args", PydanticUndefined)
if sa_column_args is not Undefined: if sa_column_args is not PydanticUndefined:
args.extend(list(cast(Sequence[Any], sa_column_args))) args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined)
if sa_column_kwargs is not Undefined: if sa_column_kwargs is not PydanticUndefined:
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs) # type: ignore return Column(sa_type, *args, **kwargs) # type: ignore
@ -462,13 +477,6 @@ class_registry = weakref.WeakValueDictionary() # type: ignore
default_registry = registry() default_registry = registry()
def _value_items_is_true(v: Any) -> bool:
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
# the current latest, Pydantic 1.8.2
return v is True or v is ...
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
@ -476,43 +484,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",) __slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]] __tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
__name__: ClassVar[str] __name__: ClassVar[str]
metadata: ClassVar[MetaData] metadata: ClassVar[MetaData]
__allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
model_config = SQLModelConfig(from_attributes=True)
class Config:
orm_mode = True
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
new_object = super().__new__(cls)
# SQLAlchemy doesn't call __init__ on the base class
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
# Set __fields_set__ here, that would have been set when calling __init__
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__fields_set__", set())
return new_object
def __init__(__pydantic_self__, **data: Any) -> None: def __init__(__pydantic_self__, **data: Any) -> None:
# Uses something other than `self` the first arg to allow "self" as a old_dict = __pydantic_self__.__dict__.copy()
# settable attribute super().__init__(**data)
values, fields_set, validation_error = validate_model( __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
__pydantic_self__.__class__, data non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
)
# Only raise errors if not a SQLModel model
if (
not getattr(__pydantic_self__.__config__, "table", False)
and validation_error
):
raise validation_error
# Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
# can handle them
# object.__setattr__(__pydantic_self__, '__dict__', values)
for key, value in values.items():
setattr(__pydantic_self__, key, value)
object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
non_pydantic_keys = data.keys() - values.keys()
for key in non_pydantic_keys: for key in non_pydantic_keys:
if key in __pydantic_self__.__sqlmodel_relationships__: if key in __pydantic_self__.__sqlmodel_relationships__:
setattr(__pydantic_self__, key, data[key]) setattr(__pydantic_self__, key, data[key])
@ -523,135 +505,36 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
return return
else: else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates # Set in SQLAlchemy, before Pydantic to trigger events and updates
if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore if self.model_config.get("table", False) and is_instrumented(self, name):
set_attribute(self, name, value) set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for # Set in Pydantic model to trigger possible validation changes, only for
# non relationship values # non relationship values
if name not in self.__sqlmodel_relationships__: if name not in self.__sqlmodel_relationships__:
super().__setattr__(name, value) super(SQLModel, self).__setattr__(name, value)
@classmethod
def from_orm(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
# Duplicated from Pydantic
if not cls.__config__.orm_mode:
raise ConfigError(
"You must have the config attribute orm_mode=True to use from_orm"
)
obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
# SQLModel, support update dict
if update is not None:
obj = {**obj, **update}
# End SQLModel support dict
if not getattr(cls.__config__, "table", False):
# If not table, normal Pydantic code
m: _TSQLModel = cls.__new__(cls)
else:
# If table, create the new instance normally to make SQLAlchemy create
# the _sa_instance_state attribute
m = cls()
values, fields_set, validation_error = validate_model(cls, obj)
if validation_error:
raise validation_error
# Updated to trigger SQLAlchemy internal handling
if not getattr(cls.__config__, "table", False):
object.__setattr__(m, "__dict__", values)
else:
for key, value in values.items():
setattr(m, key, value)
# Continue with standard Pydantic logic
object.__setattr__(m, "__fields_set__", fields_set)
m._init_private_attributes()
return m
@classmethod
def parse_obj(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
obj = cls._enforce_dict_if_root(obj)
# SQLModel, support update dict
if update is not None:
obj = {**obj, **update}
# End SQLModel support dict
return super().parse_obj(obj)
def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
# Don't show SQLAlchemy private attributes # Don't show SQLAlchemy private attributes
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
# From Pydantic, override to enforce validation with dict
@classmethod
def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel:
if isinstance(value, cls):
return value.copy() if cls.__config__.copy_on_model_validation else value
value = cls._enforce_dict_if_root(value)
if isinstance(value, dict):
values, fields_set, validation_error = validate_model(cls, value)
if validation_error:
raise validation_error
model = cls(**value)
# Reset fields set, this would have been done in Pydantic in __init__
object.__setattr__(model, "__fields_set__", fields_set)
return model
elif cls.__config__.orm_mode:
return cls.from_orm(value)
elif cls.__custom_root_type__:
return cls.parse_obj(value)
else:
try:
value_as_dict = dict(value)
except (TypeError, ValueError) as e:
raise DictError() from e
return cls(**value_as_dict)
# From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
def _calculate_keys(
self,
include: Optional[Mapping[Union[int, str], Any]],
exclude: Optional[Mapping[Union[int, str], Any]],
exclude_unset: bool,
update: Optional[Dict[str, Any]] = None,
) -> Optional[AbstractSet[str]]:
if include is None and exclude is None and not exclude_unset:
# Original in Pydantic:
# return None
# Updated to not return SQLAlchemy attributes
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
keys: AbstractSet[str]
if exclude_unset:
keys = self.__fields_set__.copy()
else:
# Original in Pydantic:
# keys = self.__dict__.keys()
# Updated to not return SQLAlchemy attributes
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
if include is not None:
keys &= include.keys()
if update:
keys -= update.keys()
if exclude:
keys -= {k for k, v in exclude.items() if _value_items_is_true(v)}
return keys
@declared_attr # type: ignore @declared_attr # type: ignore
def __tablename__(cls) -> str: def __tablename__(cls) -> str:
return cls.__name__.lower() return cls.__name__.lower()
def _is_field_noneable(field: ModelField) -> bool: def _is_field_noneable(field: FieldInfo) -> bool:
if not field.required: if not field.is_required():
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) if field.annotation is None or field.annotation is type(None):
return field.allow_none and ( return True
field.shape != SHAPE_SINGLETON or not field.sub_fields if get_origin(field.annotation) is Union:
) for base in get_args(field.annotation):
if base is type(None):
return True
return False
return False return False
def _get_field_metadata(field: FieldInfo) -> object:
for meta in field.metadata:
if isinstance(meta, PydanticGeneralMetadata):
return meta
return object()

7
sqlmodel/typing.py Normal file
View File

@ -0,0 +1,7 @@
from pydantic import ConfigDict
from typing import Optional, Any
class SQLModelConfig(ConfigDict):
table: Optional[bool]
read_from_attributes: Optional[bool]
registry: Optional[Any]

View File

@ -8,7 +8,7 @@ from sqlmodel import Field, SQLModel
def test_allow_instantiation_without_arguments(clear_sqlmodel): def test_allow_instantiation_without_arguments(clear_sqlmodel):
class Item(SQLModel): class Item(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
name: str name: str
description: Optional[str] = None description: Optional[str] = None
class Config: class Config:

View File

@ -2,10 +2,11 @@ from typing import Optional
import pytest import pytest
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
from pydantic import BaseModel
def test_missing_sql_type(): def test_missing_sql_type():
class CustomType: class CustomType(BaseModel):
@classmethod @classmethod
def __get_validators__(cls): def __get_validators__(cls):
yield cls.validate yield cls.validate

View File

@ -1,7 +1,7 @@
from typing import Optional from typing import Optional
import pytest import pytest
from pydantic import validator from pydantic import field_validator
from pydantic.error_wrappers import ValidationError from pydantic.error_wrappers import ValidationError
from sqlmodel import SQLModel from sqlmodel import SQLModel
@ -22,12 +22,12 @@ def test_validation(clear_sqlmodel):
secret_name: Optional[str] = None secret_name: Optional[str] = None
age: Optional[int] = None age: Optional[int] = None
@validator("name", "secret_name", "age") @field_validator("name", "secret_name", "age")
def reject_none(cls, v): def reject_none(cls, v):
assert v is not None assert v is not None
return v return v
Hero.validate({"age": 25}) Hero.model_validate({"age": 25})
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Hero.validate({"name": None, "age": 25}) Hero.model_validate({"name": None, "age": 25})