From a7f8a27c2090c5f164a56d913a4ae85645aae497 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 31 Jul 2023 13:38:26 +0000 Subject: [PATCH] Upgrade to Pydantic 2 Change imports Undefined => PydanticUndefined Update SQLModelMetaclass and SQLModel __init__ and __new__ functions Update SQL Alchemy type inference --- docs/tutorial/fastapi/multiple-models.md | 6 +- .../fastapi/app_testing/tutorial001/main.py | 2 +- .../tutorial/fastapi/delete/tutorial001.py | 2 +- .../fastapi/limit_and_offset/tutorial001.py | 2 +- .../fastapi/multiple_models/tutorial001.py | 2 +- .../fastapi/multiple_models/tutorial002.py | 2 +- .../tutorial/fastapi/read_one/tutorial001.py | 2 +- .../fastapi/relationships/tutorial001.py | 4 +- .../session_with_dependency/tutorial001.py | 2 +- .../tutorial/fastapi/teams/tutorial001.py | 4 +- .../tutorial/fastapi/update/tutorial001.py | 2 +- pyproject.toml | 5 +- sqlmodel/main.py | 403 +++++++----------- sqlmodel/typing.py | 7 + tests/test_instance_no_args.py | 2 +- tests/test_missing_type.py | 3 +- tests/test_validation.py | 8 +- 17 files changed, 175 insertions(+), 283 deletions(-) create mode 100644 sqlmodel/typing.py diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md index c37fad3..2ebfe20 100644 --- a/docs/tutorial/fastapi/multiple-models.md +++ b/docs/tutorial/fastapi/multiple-models.md @@ -174,13 +174,13 @@ Now we use the type annotation `HeroCreate` for the request JSON data in the `he # Code below omitted 👇 ``` -Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.from_orm()`. +Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.model_validate()`. -The method `.from_orm()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. +The method `.model_validate()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. The alternative is `Hero.parse_obj()` that reads data from a dictionary. -But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.from_orm()` to read those attributes. +But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index 88b8fbb..c780965 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 3c15efb..3053300 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index aef2133..d43aa1f 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index df20123..7f59ac6 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -46,7 +46,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index 392c2c5..fffbe72 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index 4d66e47..f18426e 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index 97220b9..248c3a2 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -92,7 +92,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -147,7 +147,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index 88b8fbb..c780965 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index 2a0bd60..2cc198e 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -83,7 +83,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -138,7 +138,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index 3555487..47f0197 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/pyproject.toml b/pyproject.toml index 0a25be4..e6489c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" SQLAlchemy = ">=2.0.0,<=2.0.11" -pydantic = "^1.8.2" +pydantic = "^2.1.1" [tool.poetry.dev-dependencies] pytest = "^7.0.1" @@ -46,12 +46,13 @@ pillow = "^9.3.0" cairosvg = "^2.5.2" mdx-include = "^1.4.1" coverage = {extras = ["toml"], version = "^6.2"} -fastapi = "^0.68.1" +fastapi = "^0.100.0" requests = "^2.26.0" autoflake = "^1.4" isort = "^5.9.3" async_generator = {version = "*", python = "~3.7"} async-exit-stack = {version = "*", python = "~3.7"} +httpx = "^0.24.1" [build-system] requires = ["poetry-core"] diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d05fdcc..46a8ba3 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -19,19 +19,17 @@ from typing import ( Set, Tuple, Type, - TypeVar, + TypeVar,ForwardRef, Union, - cast, + cast,get_origin,get_args ) -from pydantic import BaseConfig, BaseModel -from pydantic.errors import ConfigError, DictError -from pydantic.fields import SHAPE_SINGLETON +from pydantic import BaseModel, ValidationError +from pydantic_core import PydanticUndefined, PydanticUndefinedType from pydantic.fields import FieldInfo as PydanticFieldInfo -from pydantic.fields import ModelField, Undefined, UndefinedType -from pydantic.main import ModelMetaclass, validate_model -from pydantic.typing import NoArgAnyCallable, resolve_annotations -from pydantic.utils import ROOT_KEY, Representation +from pydantic._internal._model_construction import ModelMetaclass +from pydantic._internal._repr import Representation +from pydantic._internal._fields import PydanticGeneralMetadata from sqlalchemy import Boolean, Column, Date, DateTime from sqlalchemy import Enum as sa_Enum from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect @@ -43,8 +41,10 @@ from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time from .sql.sqltypes import GUID, AutoString +from .typing import SQLModelConfig _T = TypeVar("_T") +NoArgAnyCallable = Callable[[], Any] def __dataclass_transform__( @@ -58,22 +58,22 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): - def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", Undefined) - foreign_key = kwargs.pop("foreign_key", Undefined) + nullable = kwargs.pop("nullable", PydanticUndefined) + foreign_key = kwargs.pop("foreign_key", PydanticUndefined) unique = kwargs.pop("unique", False) - index = kwargs.pop("index", Undefined) - sa_column = kwargs.pop("sa_column", Undefined) - sa_column_args = kwargs.pop("sa_column_args", Undefined) - sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) - if sa_column is not Undefined: - if sa_column_args is not Undefined: + index = kwargs.pop("index", PydanticUndefined) + sa_column = kwargs.pop("sa_column", PydanticUndefined) + sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined) + if sa_column is not PydanticUndefined: + if sa_column_args is not PydanticUndefined: raise RuntimeError( "Passing sa_column_args is not supported when " "also passing a sa_column" ) - if sa_column_kwargs is not Undefined: + if sa_column_kwargs is not PydanticUndefined: raise RuntimeError( "Passing sa_column_kwargs is not supported when " "also passing a sa_column" @@ -118,7 +118,7 @@ class RelationshipInfo(Representation): def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -145,11 +145,11 @@ def Field( primary_key: bool = False, foreign_key: Optional[Any] = None, unique: bool = False, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[Mapping[str, Any], PydanticUndefinedType] = PydanticUndefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -183,7 +183,6 @@ def Field( sa_column_kwargs=sa_column_kwargs, **current_schema_extra, ) - field_info._validate() return field_info @@ -191,7 +190,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship: Optional[RelationshipProperty] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -208,18 +207,18 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - __config__: Type[BaseConfig] - __fields__: Dict[str, ModelField] + model_config: Type[SQLModelConfig] + model_fields: Dict[str, FieldInfo] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if getattr(cls.__config__, "table", False): + if cls.model_config.get("table", False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if getattr(cls.__config__, "table", False): + if cls.model_config.get("table", False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -232,11 +231,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): class_dict: Dict[str, Any], **kwargs: Any, ) -> Any: + relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} - original_annotations = resolve_annotations( - class_dict.get("__annotations__", {}), class_dict.get("__module__", None) - ) + original_annotations = class_dict.get("__annotations__", {}) pydantic_annotations = {} relationship_annotations = {} for k, v in class_dict.items(): @@ -260,7 +258,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): # superclass causing an error allowed_config_kwargs: Set[str] = { key - for key in dir(BaseConfig) + for key in dir(SQLModelConfig) if not ( key.startswith("__") and key.endswith("__") ) # skip dunder methods and attributes @@ -270,41 +268,49 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } + config_table = getattr(class_dict.get('Config', object()), 'table', False) + # If we have a table, we need to have defaults for all fields + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything + if config_table is True: + for key in original_annotations.keys(): + if dict_used.get(key, PydanticUndefined) is PydanticUndefined: + dict_used[key] = None + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, **new_cls.__annotations__, } - + def get_config(name: str) -> Any: - config_class_value = getattr(new_cls.__config__, name, Undefined) - if config_class_value is not Undefined: + config_class_value = new_cls.model_config.get(name, PydanticUndefined) + if config_class_value is not PydanticUndefined: return config_class_value - kwarg_value = kwargs.get(name, Undefined) - if kwarg_value is not Undefined: + kwarg_value = kwargs.get(name, PydanticUndefined) + if kwarg_value is not PydanticUndefined: return kwarg_value - return Undefined + return PydanticUndefined config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.table = config_table - for k, v in new_cls.__fields__.items(): + new_cls.model_config['table'] = config_table + for k, v in new_cls.model_fields.items(): col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. - # This could be done by reading new_cls.__config__.table in FastAPI, but + # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.__config__.read_with_orm_mode = True + new_cls.model_config['read_from_attributes'] = True config_registry = get_config("registry") - if config_registry is not Undefined: + if config_registry is not PydanticUndefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.registry = config_table + new_cls.model_config['registry'] = config_table setattr(new_cls, "_sa_registry", config_registry) setattr(new_cls, "metadata", config_registry.metadata) setattr(new_cls, "__abstract__", True) @@ -320,13 +326,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): # triggers an error base_is_table = False for base in bases: - config = getattr(base, "__config__") + config = getattr(base, "model_config") if config and getattr(config, "table", False): base_is_table = True break - if getattr(cls.__config__, "table", False) and not base_is_table: + if cls.model_config.get("table", False) and not base_is_table: dict_used = dict_.copy() - for field_name, field_value in cls.__fields__.items(): + for field_name, field_value in cls.model_fields.items(): dict_used[field_name] = get_column_from_field(field_value) for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: @@ -335,16 +341,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): dict_used[rel_name] = rel_info.sa_relationship continue ann = cls.__annotations__[rel_name] - temp_field = ModelField.infer( - name=rel_name, - value=rel_info, - annotation=ann, - class_validators=None, - config=BaseConfig, - ) - relationship_to = temp_field.type_ - if isinstance(temp_field.type_, ForwardRef): - relationship_to = temp_field.type_.__forward_arg__ + relationship_to = get_origin(ann) + # If Union (Optional), get the real field + if relationship_to is Union: + relationship_to = get_args(ann)[0] + # If a list, then also get the real field + elif relationship_to is list: + relationship_to = get_args(ann)[0] + if isinstance(relationship_to, ForwardRef): + relationship_to = relationship_to.__forward_arg__ rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates @@ -362,7 +367,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( # type: ignore + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) dict_used[rel_name] = rel_value @@ -372,68 +377,78 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlalchemy_type(field: ModelField) -> Any: - if issubclass(field.type_, str): - if field.field_info.max_length: - return AutoString(length=field.field_info.max_length) +def get_sqlalchemy_type(field: FieldInfo) -> Any: + type_ = field.annotation + # Resolve Optional fields + if type_ is not None and get_origin(type_) is Union: + bases = get_args(type_) + if len(bases) > 2: + raise RuntimeError("Cannot have a (non-optional) union as a SQL alchemy field") + type_ = bases[0] + + # The 3rd is PydanticGeneralMetadata + metadata = _get_field_metadata(field) + if issubclass(type_, str): + if getattr(metadata, "max_length", None): + return AutoString(length=metadata.max_length) return AutoString - if issubclass(field.type_, float): + if issubclass(type_, float): return Float - if issubclass(field.type_, bool): + if issubclass(type_, bool): return Boolean - if issubclass(field.type_, int): + if issubclass(type_, int): return Integer - if issubclass(field.type_, datetime): + if issubclass(type_, datetime): return DateTime - if issubclass(field.type_, date): + if issubclass(type_, date): return Date - if issubclass(field.type_, timedelta): + if issubclass(type_, timedelta): return Interval - if issubclass(field.type_, time): + if issubclass(type_, time): return Time - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) - if issubclass(field.type_, bytes): + if issubclass(type_, Enum): + return sa_Enum(type_) + if issubclass(type_, bytes): return LargeBinary - if issubclass(field.type_, Decimal): + if issubclass(type_, Decimal): return Numeric( - precision=getattr(field.type_, "max_digits", None), - scale=getattr(field.type_, "decimal_places", None), + precision=getattr(metadata, "max_digits", None), + scale=getattr(metadata, "decimal_places", None), ) - if issubclass(field.type_, ipaddress.IPv4Address): + if issubclass(type_, ipaddress.IPv4Address): return AutoString - if issubclass(field.type_, ipaddress.IPv4Network): + if issubclass(type_, ipaddress.IPv4Network): return AutoString - if issubclass(field.type_, ipaddress.IPv6Address): + if issubclass(type_, ipaddress.IPv6Address): return AutoString - if issubclass(field.type_, ipaddress.IPv6Network): + if issubclass(type_, ipaddress.IPv6Network): return AutoString - if issubclass(field.type_, Path): + if issubclass(type_, Path): return AutoString - if issubclass(field.type_, uuid.UUID): + if issubclass(type_, uuid.UUID): return GUID - raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + raise ValueError(f"The field {field.title} has no matching SQLAlchemy type") -def get_column_from_field(field: ModelField) -> Column: # type: ignore - sa_column = getattr(field.field_info, "sa_column", Undefined) +def get_column_from_field(field: FieldInfo) -> Column: # type: ignore + sa_column = getattr(field, "sa_column", PydanticUndefined) if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", False) - index = getattr(field.field_info, "index", Undefined) - if index is Undefined: + primary_key = getattr(field, "primary_key", False) + index = getattr(field, "index", PydanticUndefined) + if index is PydanticUndefined: index = False nullable = not primary_key and _is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - if hasattr(field.field_info, "nullable"): - field_nullable = getattr(field.field_info, "nullable") - if field_nullable != Undefined: + if hasattr(field, "nullable"): + field_nullable = getattr(field, "nullable") + if field_nullable != PydanticUndefined: nullable = field_nullable args = [] - foreign_key = getattr(field.field_info, "foreign_key", None) - unique = getattr(field.field_info, "unique", False) + foreign_key = getattr(field, "foreign_key", None) + unique = getattr(field, "unique", False) if foreign_key: args.append(ForeignKey(foreign_key)) kwargs = { @@ -442,18 +457,18 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default = Undefined - if field.field_info.default_factory: - sa_default = field.field_info.default_factory - elif field.field_info.default is not Undefined: - sa_default = field.field_info.default - if sa_default is not Undefined: + sa_default = PydanticUndefined + if field.default_factory: + sa_default = field.default_factory + elif field.default is not PydanticUndefined: + sa_default = field.default + if sa_default is not PydanticUndefined: kwargs["default"] = sa_default - sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) - if sa_column_args is not Undefined: + sa_column_args = getattr(field, "sa_column_args", PydanticUndefined) + if sa_column_args is not PydanticUndefined: args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) - if sa_column_kwargs is not Undefined: + sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined) + if sa_column_kwargs is not PydanticUndefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore @@ -462,13 +477,6 @@ class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() - -def _value_items_is_true(v: Any) -> bool: - # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of - # the current latest, Pydantic 1.8.2 - return v is True or v is ... - - _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") @@ -476,43 +484,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six - - class Config: - orm_mode = True - - def __new__(cls, *args: Any, **kwargs: Any) -> Any: - new_object = super().__new__(cls) - # SQLAlchemy doesn't call __init__ on the base class - # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html - # Set __fields_set__ here, that would have been set when calling __init__ - # in the Pydantic model so that when SQLAlchemy sets attributes that are - # added (e.g. when querying from DB) to the __fields_set__, this already exists - object.__setattr__(new_object, "__fields_set__", set()) - return new_object + model_config = SQLModelConfig(from_attributes=True) def __init__(__pydantic_self__, **data: Any) -> None: - # Uses something other than `self` the first arg to allow "self" as a - # settable attribute - values, fields_set, validation_error = validate_model( - __pydantic_self__.__class__, data - ) - # Only raise errors if not a SQLModel model - if ( - not getattr(__pydantic_self__.__config__, "table", False) - and validation_error - ): - raise validation_error - # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy - # can handle them - # object.__setattr__(__pydantic_self__, '__dict__', values) - for key, value in values.items(): - setattr(__pydantic_self__, key, value) - object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) - non_pydantic_keys = data.keys() - values.keys() + old_dict = __pydantic_self__.__dict__.copy() + super().__init__(**data) + __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__ + non_pydantic_keys = data.keys() - __pydantic_self__.model_fields for key in non_pydantic_keys: if key in __pydantic_self__.__sqlmodel_relationships__: setattr(__pydantic_self__, key, data[key]) @@ -523,135 +505,36 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore + if self.model_config.get("table", False) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values if name not in self.__sqlmodel_relationships__: - super().__setattr__(name, value) - - @classmethod - def from_orm( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - # Duplicated from Pydantic - if not cls.__config__.orm_mode: - raise ConfigError( - "You must have the config attribute orm_mode=True to use from_orm" - ) - obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - if not getattr(cls.__config__, "table", False): - # If not table, normal Pydantic code - m: _TSQLModel = cls.__new__(cls) - else: - # If table, create the new instance normally to make SQLAlchemy create - # the _sa_instance_state attribute - m = cls() - values, fields_set, validation_error = validate_model(cls, obj) - if validation_error: - raise validation_error - # Updated to trigger SQLAlchemy internal handling - if not getattr(cls.__config__, "table", False): - object.__setattr__(m, "__dict__", values) - else: - for key, value in values.items(): - setattr(m, key, value) - # Continue with standard Pydantic logic - object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() - return m - - @classmethod - def parse_obj( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - return super().parse_obj(obj) + super(SQLModel, self).__setattr__(name, value) def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] - # From Pydantic, override to enforce validation with dict - @classmethod - def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: - if isinstance(value, cls): - return value.copy() if cls.__config__.copy_on_model_validation else value - - value = cls._enforce_dict_if_root(value) - if isinstance(value, dict): - values, fields_set, validation_error = validate_model(cls, value) - if validation_error: - raise validation_error - model = cls(**value) - # Reset fields set, this would have been done in Pydantic in __init__ - object.__setattr__(model, "__fields_set__", fields_set) - return model - elif cls.__config__.orm_mode: - return cls.from_orm(value) - elif cls.__custom_root_type__: - return cls.parse_obj(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) - - # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes - def _calculate_keys( - self, - include: Optional[Mapping[Union[int, str], Any]], - exclude: Optional[Mapping[Union[int, str], Any]], - exclude_unset: bool, - update: Optional[Dict[str, Any]] = None, - ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and not exclude_unset: - # Original in Pydantic: - # return None - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - - keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() - else: - # Original in Pydantic: - # keys = self.__dict__.keys() - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - if include is not None: - keys &= include.keys() - - if update: - keys -= update.keys() - - if exclude: - keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} - - return keys @declared_attr # type: ignore def __tablename__(cls) -> str: return cls.__name__.lower() -def _is_field_noneable(field: ModelField) -> bool: - if not field.required: - # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) - return field.allow_none and ( - field.shape != SHAPE_SINGLETON or not field.sub_fields - ) +def _is_field_noneable(field: FieldInfo) -> bool: + if not field.is_required(): + if field.annotation is None or field.annotation is type(None): + return True + if get_origin(field.annotation) is Union: + for base in get_args(field.annotation): + if base is type(None): + return True + return False return False + +def _get_field_metadata(field: FieldInfo) -> object: + for meta in field.metadata: + if isinstance(meta, PydanticGeneralMetadata): + return meta + return object() \ No newline at end of file diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py new file mode 100644 index 0000000..570da2f --- /dev/null +++ b/sqlmodel/typing.py @@ -0,0 +1,7 @@ +from pydantic import ConfigDict +from typing import Optional, Any + +class SQLModelConfig(ConfigDict): + table: Optional[bool] + read_from_attributes: Optional[bool] + registry: Optional[Any] \ No newline at end of file diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 14d5606..5dc520c 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -8,7 +8,7 @@ from sqlmodel import Field, SQLModel def test_allow_instantiation_without_arguments(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) - name: str + name: str description: Optional[str] = None class Config: diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index 2185fa4..dd12b25 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -2,10 +2,11 @@ from typing import Optional import pytest from sqlmodel import Field, SQLModel +from pydantic import BaseModel def test_missing_sql_type(): - class CustomType: + class CustomType(BaseModel): @classmethod def __get_validators__(cls): yield cls.validate diff --git a/tests/test_validation.py b/tests/test_validation.py index a3ff6e3..4883fb2 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,7 +1,7 @@ from typing import Optional import pytest -from pydantic import validator +from pydantic import field_validator from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel @@ -22,12 +22,12 @@ def test_validation(clear_sqlmodel): secret_name: Optional[str] = None age: Optional[int] = None - @validator("name", "secret_name", "age") + @field_validator("name", "secret_name", "age") def reject_none(cls, v): assert v is not None return v - Hero.validate({"age": 25}) + Hero.model_validate({"age": 25}) with pytest.raises(ValidationError): - Hero.validate({"name": None, "age": 25}) + Hero.model_validate({"name": None, "age": 25})