commit fcff2050e671c68de9cf5d8b7bc22b05119db1c4 Author: Sebastián Ramírez Date: Tue Aug 24 14:41:53 2021 +0200 ✨ Add SQLModel core code diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py new file mode 100644 index 0000000..cdfb889 --- /dev/null +++ b/sqlmodel/__init__.py @@ -0,0 +1,139 @@ +__version__ = "0.0.1" + +# Re-export from SQLAlchemy +from sqlalchemy.engine import create_mock_engine as create_mock_engine +from sqlalchemy.engine import engine_from_config as engine_from_config +from sqlalchemy.inspection import inspect as inspect +from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA +from sqlalchemy.schema import CheckConstraint as CheckConstraint +from sqlalchemy.schema import Column as Column +from sqlalchemy.schema import ColumnDefault as ColumnDefault +from sqlalchemy.schema import Computed as Computed +from sqlalchemy.schema import Constraint as Constraint +from sqlalchemy.schema import DDL as DDL +from sqlalchemy.schema import DefaultClause as DefaultClause +from sqlalchemy.schema import FetchedValue as FetchedValue +from sqlalchemy.schema import ForeignKey as ForeignKey +from sqlalchemy.schema import ForeignKeyConstraint as ForeignKeyConstraint +from sqlalchemy.schema import Identity as Identity +from sqlalchemy.schema import Index as Index +from sqlalchemy.schema import MetaData as MetaData +from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from sqlalchemy.schema import Sequence as Sequence +from sqlalchemy.schema import Table as Table +from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData +from sqlalchemy.schema import UniqueConstraint as UniqueConstraint +from sqlalchemy.sql import alias as alias +from sqlalchemy.sql import all_ as all_ +from sqlalchemy.sql import and_ as and_ +from sqlalchemy.sql import any_ as any_ +from sqlalchemy.sql import asc as asc +from sqlalchemy.sql import between as between +from sqlalchemy.sql import bindparam as bindparam +from sqlalchemy.sql import case as case +from sqlalchemy.sql import cast as cast +from sqlalchemy.sql import collate as collate +from sqlalchemy.sql import column as column +from sqlalchemy.sql import delete as delete +from sqlalchemy.sql import desc as desc +from sqlalchemy.sql import distinct as distinct +from sqlalchemy.sql import except_ as except_ +from sqlalchemy.sql import except_all as except_all +from sqlalchemy.sql import exists as exists +from sqlalchemy.sql import extract as extract +from sqlalchemy.sql import false as false +from sqlalchemy.sql import func as func +from sqlalchemy.sql import funcfilter as funcfilter +from sqlalchemy.sql import insert as insert +from sqlalchemy.sql import intersect as intersect +from sqlalchemy.sql import intersect_all as intersect_all +from sqlalchemy.sql import join as join +from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from sqlalchemy.sql import ( + LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, +) +from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from sqlalchemy.sql import ( + LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, +) +from sqlalchemy.sql import lambda_stmt as lambda_stmt +from sqlalchemy.sql import lateral as lateral +from sqlalchemy.sql import literal as literal +from sqlalchemy.sql import literal_column as literal_column +from sqlalchemy.sql import modifier as modifier +from sqlalchemy.sql import not_ as not_ +from sqlalchemy.sql import null as null +from sqlalchemy.sql import nulls_first as nulls_first +from sqlalchemy.sql import nulls_last as nulls_last +from sqlalchemy.sql import nullsfirst as nullsfirst +from sqlalchemy.sql import nullslast as nullslast +from sqlalchemy.sql import or_ as or_ +from sqlalchemy.sql import outerjoin as outerjoin +from sqlalchemy.sql import outparam as outparam +from sqlalchemy.sql import over as over +from sqlalchemy.sql import subquery as subquery +from sqlalchemy.sql import table as table +from sqlalchemy.sql import tablesample as tablesample +from sqlalchemy.sql import text as text +from sqlalchemy.sql import true as true +from sqlalchemy.sql import tuple_ as tuple_ +from sqlalchemy.sql import type_coerce as type_coerce +from sqlalchemy.sql import union as union +from sqlalchemy.sql import union_all as union_all +from sqlalchemy.sql import update as update +from sqlalchemy.sql import values as values +from sqlalchemy.sql import within_group as within_group +from sqlalchemy.types import ARRAY as ARRAY +from sqlalchemy.types import BIGINT as BIGINT +from sqlalchemy.types import BigInteger as BigInteger +from sqlalchemy.types import BINARY as BINARY +from sqlalchemy.types import BLOB as BLOB +from sqlalchemy.types import BOOLEAN as BOOLEAN +from sqlalchemy.types import Boolean as Boolean +from sqlalchemy.types import CHAR as CHAR +from sqlalchemy.types import CLOB as CLOB +from sqlalchemy.types import DATE as DATE +from sqlalchemy.types import Date as Date +from sqlalchemy.types import DATETIME as DATETIME +from sqlalchemy.types import DateTime as DateTime +from sqlalchemy.types import DECIMAL as DECIMAL +from sqlalchemy.types import Enum as Enum +from sqlalchemy.types import FLOAT as FLOAT +from sqlalchemy.types import Float as Float +from sqlalchemy.types import INT as INT +from sqlalchemy.types import INTEGER as INTEGER +from sqlalchemy.types import Integer as Integer +from sqlalchemy.types import Interval as Interval +from sqlalchemy.types import JSON as JSON +from sqlalchemy.types import LargeBinary as LargeBinary +from sqlalchemy.types import NCHAR as NCHAR +from sqlalchemy.types import NUMERIC as NUMERIC +from sqlalchemy.types import Numeric as Numeric +from sqlalchemy.types import NVARCHAR as NVARCHAR +from sqlalchemy.types import PickleType as PickleType +from sqlalchemy.types import REAL as REAL +from sqlalchemy.types import SMALLINT as SMALLINT +from sqlalchemy.types import SmallInteger as SmallInteger +from sqlalchemy.types import String as String +from sqlalchemy.types import TEXT as TEXT +from sqlalchemy.types import Text as Text +from sqlalchemy.types import TIME as TIME +from sqlalchemy.types import Time as Time +from sqlalchemy.types import TIMESTAMP as TIMESTAMP +from sqlalchemy.types import TypeDecorator as TypeDecorator +from sqlalchemy.types import Unicode as Unicode +from sqlalchemy.types import UnicodeText as UnicodeText +from sqlalchemy.types import VARBINARY as VARBINARY +from sqlalchemy.types import VARCHAR as VARCHAR + +# Extensions and modifications of SQLAlchemy in SQLModel +from .engine.create import create_engine as create_engine +from .orm.session import Session as Session +from .sql.expression import select as select +from .sql.expression import col as col +from .sql.sqltypes import AutoString as AutoString + +# Export SQLModel specifics (equivalent to Pydantic) +from .main import SQLModel as SQLModel +from .main import Field as Field +from .main import Relationship as Relationship diff --git a/sqlmodel/default.py b/sqlmodel/default.py new file mode 100644 index 0000000..bb44972 --- /dev/null +++ b/sqlmodel/default.py @@ -0,0 +1,32 @@ +from typing import Any, TypeVar + + +class _DefaultPlaceholder: + """ + You shouldn't use this class directly. + + It's used internally to recognize when a default value has been overwritten, even + if the overriden default value was truthy. + """ + + def __init__(self, value: Any): + self.value = value + + def __bool__(self) -> bool: + return bool(self.value) + + def __eq__(self, o: object) -> bool: + return isinstance(o, _DefaultPlaceholder) and o.value == self.value + + +_TDefaultType = TypeVar("_TDefaultType") + + +def Default(value: _TDefaultType) -> _TDefaultType: + """ + You shouldn't use this function directly. + + It's used internally to recognize when a default value has been overwritten, even + if the overriden default value was truthy. + """ + return _DefaultPlaceholder(value) # type: ignore diff --git a/sqlmodel/engine/__init__.py b/sqlmodel/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py new file mode 100644 index 0000000..9748125 --- /dev/null +++ b/sqlmodel/engine/create.py @@ -0,0 +1,139 @@ +import json +import sqlite3 +from typing import Any, Callable, Dict, List, Optional, Type, Union + +from sqlalchemy import create_engine as _create_engine +from sqlalchemy.engine.url import URL +from sqlalchemy.future import Engine as _FutureEngine +from sqlalchemy.pool import Pool +from typing_extensions import Literal, TypedDict + +from ..default import Default, _DefaultPlaceholder + +# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here + +_Debug = Literal["debug"] + +_IsolationLevel = Literal[ + "SERIALIZABLE", + "REPEATABLE READ", + "READ COMMITTED", + "READ UNCOMMITTED", + "AUTOCOMMIT", +] +_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"] +_ResetOnReturn = Literal["rollback", "commit"] + + +class _SQLiteConnectArgs(TypedDict, total=False): + timeout: float + detect_types: Any + isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]] + check_same_thread: bool + factory: Type[sqlite3.Connection] + cached_statements: int + uri: bool + + +_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]] + + +# Re-define create_engine to have by default future=True, and assume that's what is used +# Also show the default values used for each parameter, but don't set them unless +# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't +# support pool connection arguments. +def create_engine( + url: Union[str, URL], + *, + connect_args: _ConnectArgs = Default({}), # type: ignore + echo: Union[bool, _Debug] = Default(False), + echo_pool: Union[bool, _Debug] = Default(False), + enable_from_linting: bool = Default(True), + encoding: str = Default("utf-8"), + execution_options: Dict[Any, Any] = Default({}), + future: bool = True, + hide_parameters: bool = Default(False), + implicit_returning: bool = Default(True), + isolation_level: Optional[_IsolationLevel] = Default(None), + json_deserializer: Callable[..., Any] = Default(json.loads), + json_serializer: Callable[..., Any] = Default(json.dumps), + label_length: Optional[int] = Default(None), + logging_name: Optional[str] = Default(None), + max_identifier_length: Optional[int] = Default(None), + max_overflow: int = Default(10), + module: Optional[Any] = Default(None), + paramstyle: Optional[_ParamStyle] = Default(None), + pool: Optional[Pool] = Default(None), + poolclass: Optional[Type[Pool]] = Default(None), + pool_logging_name: Optional[str] = Default(None), + pool_pre_ping: bool = Default(False), + pool_size: int = Default(5), + pool_recycle: int = Default(-1), + pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"), + pool_timeout: float = Default(30), + pool_use_lifo: bool = Default(False), + plugins: Optional[List[str]] = Default(None), + query_cache_size: Optional[int] = Default(None), + **kwargs: Any, +) -> _FutureEngine: + current_kwargs: Dict[str, Any] = { + "future": future, + } + if not isinstance(echo, _DefaultPlaceholder): + current_kwargs["echo"] = echo + if not isinstance(echo_pool, _DefaultPlaceholder): + current_kwargs["echo_pool"] = echo_pool + if not isinstance(enable_from_linting, _DefaultPlaceholder): + current_kwargs["enable_from_linting"] = enable_from_linting + if not isinstance(connect_args, _DefaultPlaceholder): + current_kwargs["connect_args"] = connect_args + if not isinstance(encoding, _DefaultPlaceholder): + current_kwargs["encoding"] = encoding + if not isinstance(execution_options, _DefaultPlaceholder): + current_kwargs["execution_options"] = execution_options + if not isinstance(hide_parameters, _DefaultPlaceholder): + current_kwargs["hide_parameters"] = hide_parameters + if not isinstance(implicit_returning, _DefaultPlaceholder): + current_kwargs["implicit_returning"] = implicit_returning + if not isinstance(isolation_level, _DefaultPlaceholder): + current_kwargs["isolation_level"] = isolation_level + if not isinstance(json_deserializer, _DefaultPlaceholder): + current_kwargs["json_deserializer"] = json_deserializer + if not isinstance(json_serializer, _DefaultPlaceholder): + current_kwargs["json_serializer"] = json_serializer + if not isinstance(label_length, _DefaultPlaceholder): + current_kwargs["label_length"] = label_length + if not isinstance(logging_name, _DefaultPlaceholder): + current_kwargs["logging_name"] = logging_name + if not isinstance(max_identifier_length, _DefaultPlaceholder): + current_kwargs["max_identifier_length"] = max_identifier_length + if not isinstance(max_overflow, _DefaultPlaceholder): + current_kwargs["max_overflow"] = max_overflow + if not isinstance(module, _DefaultPlaceholder): + current_kwargs["module"] = module + if not isinstance(paramstyle, _DefaultPlaceholder): + current_kwargs["paramstyle"] = paramstyle + if not isinstance(pool, _DefaultPlaceholder): + current_kwargs["pool"] = pool + if not isinstance(poolclass, _DefaultPlaceholder): + current_kwargs["poolclass"] = poolclass + if not isinstance(pool_logging_name, _DefaultPlaceholder): + current_kwargs["pool_logging_name"] = pool_logging_name + if not isinstance(pool_pre_ping, _DefaultPlaceholder): + current_kwargs["pool_pre_ping"] = pool_pre_ping + if not isinstance(pool_size, _DefaultPlaceholder): + current_kwargs["pool_size"] = pool_size + if not isinstance(pool_recycle, _DefaultPlaceholder): + current_kwargs["pool_recycle"] = pool_recycle + if not isinstance(pool_reset_on_return, _DefaultPlaceholder): + current_kwargs["pool_reset_on_return"] = pool_reset_on_return + if not isinstance(pool_timeout, _DefaultPlaceholder): + current_kwargs["pool_timeout"] = pool_timeout + if not isinstance(pool_use_lifo, _DefaultPlaceholder): + current_kwargs["pool_use_lifo"] = pool_use_lifo + if not isinstance(plugins, _DefaultPlaceholder): + current_kwargs["plugins"] = plugins + if not isinstance(query_cache_size, _DefaultPlaceholder): + current_kwargs["query_cache_size"] = query_cache_size + current_kwargs.update(kwargs) + return _create_engine(url, **current_kwargs) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py new file mode 100644 index 0000000..d521427 --- /dev/null +++ b/sqlmodel/engine/result.py @@ -0,0 +1,79 @@ +from typing import Generic, Iterator, List, Optional, TypeVar + +from sqlalchemy.engine.result import Result as _Result +from sqlalchemy.engine.result import ScalarResult as _ScalarResult + +_T = TypeVar("_T") + + +class ScalarResult(_ScalarResult, Generic[_T]): + def all(self) -> List[_T]: + return super().all() + + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: + return super().partitions(size) + + def fetchall(self) -> List[_T]: + return super().fetchall() + + def fetchmany(self, size: Optional[int] = None) -> List[_T]: + return super().fetchmany(size) + + def __iter__(self) -> Iterator[_T]: + return super().__iter__() + + def __next__(self) -> _T: + return super().__next__() + + def first(self) -> Optional[_T]: + return super().first() + + def one_or_none(self) -> Optional[_T]: + return super().one_or_none() + + def one(self) -> _T: + return super().one() + + +class Result(_Result, Generic[_T]): + def scalars(self, index: int = 0) -> ScalarResult[_T]: + return super().scalars(index) # type: ignore + + def __iter__(self) -> Iterator[_T]: # type: ignore + return super().__iter__() # type: ignore + + def __next__(self) -> _T: # type: ignore + return super().__next__() # type: ignore + + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore + return super().partitions(size) # type: ignore + + def fetchall(self) -> List[_T]: # type: ignore + return super().fetchall() # type: ignore + + def fetchone(self) -> Optional[_T]: # type: ignore + return super().fetchone() # type: ignore + + def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore + return super().fetchmany() # type: ignore + + def all(self) -> List[_T]: # type: ignore + return super().all() # type: ignore + + def first(self) -> Optional[_T]: # type: ignore + return super().first() # type: ignore + + def one_or_none(self) -> Optional[_T]: # type: ignore + return super().one_or_none() # type: ignore + + def scalar_one(self) -> _T: + return super().scalar_one() # type: ignore + + def scalar_one_or_none(self) -> Optional[_T]: + return super().scalar_one_or_none() # type: ignore + + def one(self) -> _T: # type: ignore + return super().one() # type: ignore + + def scalar(self) -> Optional[_T]: + return super().scalar() # type: ignore diff --git a/sqlmodel/ext/__init__.py b/sqlmodel/ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sqlmodel/ext/asyncio/__init__.py b/sqlmodel/ext/asyncio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py new file mode 100644 index 0000000..40e5b76 --- /dev/null +++ b/sqlmodel/ext/asyncio/session.py @@ -0,0 +1,62 @@ +from typing import Any, Mapping, Optional, Sequence, TypeVar, Union + +from sqlalchemy import util +from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession +from sqlalchemy.ext.asyncio import engine +from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.util.concurrency import greenlet_spawn +from sqlmodel.sql.base import Executable + +from ...engine.result import ScalarResult +from ...orm.session import Session +from ...sql.expression import Select + +_T = TypeVar("_T") + + +class AsyncSession(_AsyncSession): + sync_session: Session + + def __init__( + self, + bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, + binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, + **kw, + ): + # All the same code of the original AsyncSession + kw["future"] = True + if bind: + self.bind = bind + bind = engine._get_sync_engine_or_connection(bind) # type: ignore + + if binds: + self.binds = binds + binds = { + key: engine._get_sync_engine_or_connection(b) # type: ignore + for key, b in binds.items() + } + + self.sync_session = self._proxied = self._assign_proxied( # type: ignore + Session(bind=bind, binds=binds, **kw) # type: ignore + ) + + async def exec( + self, + statement: Union[Select[_T], Executable[_T]], + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[Any, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> ScalarResult[_T]: + # TODO: the documentation says execution_options accepts a dict, but only + # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? + execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore + + return await greenlet_spawn( # type: ignore + self.sync_session.exec, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) diff --git a/sqlmodel/main.py b/sqlmodel/main.py new file mode 100644 index 0000000..8036ceb --- /dev/null +++ b/sqlmodel/main.py @@ -0,0 +1,631 @@ +import ipaddress +import uuid +import weakref +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from pathlib import Path +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + ClassVar, + Dict, + ForwardRef, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from pydantic import BaseModel +from pydantic.errors import ConfigError, DictError +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import ModelField, Undefined, UndefinedType +from pydantic.main import BaseConfig, ModelMetaclass, validate_model +from pydantic.typing import NoArgAnyCallable, resolve_annotations +from pydantic.utils import ROOT_KEY, Representation, ValueItems +from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + Interval, + Numeric, + inspect, +) +from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship +from sqlalchemy.orm.attributes import set_attribute +from sqlalchemy.orm.decl_api import DeclarativeMeta +from sqlalchemy.orm.instrumentation import is_instrumented +from sqlalchemy.sql.schema import MetaData +from sqlalchemy.sql.sqltypes import LargeBinary, Time + +from .sql.sqltypes import GUID, AutoString + +_T = TypeVar("_T") + + +def __dataclass_transform__( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()), +) -> Callable[[_T], _T]: + return lambda a: a + + +class FieldInfo(PydanticFieldInfo): + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + primary_key = kwargs.pop("primary_key", False) + nullable = kwargs.pop("nullable", Undefined) + foreign_key = kwargs.pop("foreign_key", Undefined) + index = kwargs.pop("index", Undefined) + sa_column = kwargs.pop("sa_column", Undefined) + sa_column_args = kwargs.pop("sa_column_args", Undefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) + if sa_column is not Undefined: + if sa_column_args is not Undefined: + raise RuntimeError( + "Passing sa_column_args is not supported when " + "also passing a sa_column" + ) + if sa_column_kwargs is not Undefined: + raise RuntimeError( + "Passing sa_column_kwargs is not supported when " + "also passing a sa_column" + ) + super().__init__(default=default, **kwargs) + self.primary_key = primary_key + self.nullable = nullable + self.foreign_key = foreign_key + self.index = index + self.sa_column = sa_column + self.sa_column_args = sa_column_args + self.sa_column_kwargs = sa_column_kwargs + + +class RelationshipInfo(Representation): + def __init__( + self, + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, + ) -> None: + if sa_relationship is not None: + if sa_relationship_args is not None: + raise RuntimeError( + "Passing sa_relationship_args is not supported when " + "also passing a sa_relationship" + ) + if sa_relationship_kwargs is not None: + raise RuntimeError( + "Passing sa_relationship_kwargs is not supported when " + "also passing a sa_relationship" + ) + self.back_populates = back_populates + self.link_model = link_model + self.sa_relationship = sa_relationship + self.sa_relationship_args = sa_relationship_args + self.sa_relationship_kwargs = sa_relationship_kwargs + + +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: str = None, + title: str = None, + description: str = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: bool = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + multiple_of: float = None, + min_items: int = None, + max_items: int = None, + min_length: int = None, + max_length: int = None, + allow_mutation: bool = True, + regex: str = None, + primary_key: bool = False, + foreign_key: Optional[Any] = None, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_column: Union[Column, UndefinedType] = Undefined, + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + current_schema_extra = schema_extra or {} + field_info = FieldInfo( + default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + exclude=exclude, + include=include, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + min_items=min_items, + max_items=max_items, + min_length=min_length, + max_length=max_length, + allow_mutation=allow_mutation, + regex=regex, + primary_key=primary_key, + foreign_key=foreign_key, + nullable=nullable, + index=index, + sa_column=sa_column, + sa_column_args=sa_column_args, + sa_column_kwargs=sa_column_kwargs, + **current_schema_extra, + ) + field_info._validate() + return field_info + + +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, +) -> Any: + relationship_info = RelationshipInfo( + back_populates=back_populates, + link_model=link_model, + sa_relationship=sa_relationship, + sa_relationship_args=sa_relationship_args, + sa_relationship_kwargs=sa_relationship_kwargs, + ) + return relationship_info + + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): + __sqlmodel_relationships__: Dict[str, RelationshipInfo] + __config__: Type[BaseConfig] + __fields__: Dict[str, ModelField] + + # Replicate SQLAlchemy + def __setattr__(cls, name: str, value: Any) -> None: + if getattr(cls.__config__, "table", False): # type: ignore + DeclarativeMeta.__setattr__(cls, name, value) + else: + super().__setattr__(name, value) + + def __delattr__(cls, name: str) -> None: + if getattr(cls.__config__, "table", False): # type: ignore + DeclarativeMeta.__delattr__(cls, name) + else: + super().__delattr__(name) + + # From Pydantic + def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any: + relationships: Dict[str, RelationshipInfo] = {} + dict_for_pydantic = {} + original_annotations = resolve_annotations( + class_dict.get("__annotations__", {}), class_dict.get("__module__", None) + ) + pydantic_annotations = {} + relationship_annotations = {} + for k, v in class_dict.items(): + if isinstance(v, RelationshipInfo): + relationships[k] = v + else: + dict_for_pydantic[k] = v + for k, v in original_annotations.items(): + if k in relationships: + relationship_annotations[k] = v + else: + pydantic_annotations[k] = v + dict_used = { + **dict_for_pydantic, + "__weakref__": None, + "__sqlmodel_relationships__": relationships, + "__annotations__": pydantic_annotations, + } + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs: Set[str] = { + key + for key in dir(BaseConfig) + if not ( + key.startswith("__") and key.endswith("__") + ) # skip dunder methods and attributes + } + pydantic_kwargs = kwargs.copy() + config_kwargs = { + key: pydantic_kwargs.pop(key) + for key in pydantic_kwargs.keys() & allowed_config_kwargs + } + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + new_cls.__annotations__ = { + **relationship_annotations, + **pydantic_annotations, + **new_cls.__annotations__, + } + + def get_config(name: str) -> Any: + config_class_value = getattr(new_cls.__config__, name, Undefined) + if config_class_value is not Undefined: + return config_class_value + kwarg_value = kwargs.get(name, Undefined) + if kwarg_value is not Undefined: + return kwarg_value + return Undefined + + config_table = get_config("table") + if config_table is True: + # If it was passed by kwargs, ensure it's also set in config + new_cls.__config__.table = config_table + for k, v in new_cls.__fields__.items(): + col = get_column_from_field(v) + setattr(new_cls, k, col) + # Set a config flag to tell FastAPI that this should be read with a field + # in orm_mode instead of preemptively converting it to a dict. + # This could be done by reading new_cls.__config__.table in FastAPI, but + # that's very specific about SQLModel, so let's have another config that + # other future tools based on Pydantic can use. + new_cls.__config__.read_with_orm_mode = True + + config_registry = get_config("registry") + if config_registry is not Undefined: + config_registry = cast(registry, config_registry) + # If it was passed by kwargs, ensure it's also set in config + new_cls.__config__.registry = config_table + setattr(new_cls, "_sa_registry", config_registry) + setattr(new_cls, "metadata", config_registry.metadata) + setattr(new_cls, "__abstract__", True) + return new_cls + + # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models + def __init__( + cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any + ) -> None: + # Only one of the base classes (or the current one) should be a table model + # this allows FastAPI cloning a SQLModel for the response_model without + # trying to create a new SQLAlchemy, for a new table, with the same name, that + # triggers an error + base_is_table = False + for base in bases: + config = getattr(base, "__config__") + if config and getattr(config, "table", False): + base_is_table = True + break + if getattr(cls.__config__, "table", False) and not base_is_table: + dict_used = dict_.copy() + for field_name, field_value in cls.__fields__.items(): + dict_used[field_name] = get_column_from_field(field_value) + for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): + if rel_info.sa_relationship: + # There's a SQLAlchemy relationship declared, that takes precedence + # over anything else, use that and continue with the next attribute + dict_used[rel_name] = rel_info.sa_relationship + continue + ann = cls.__annotations__[rel_name] + temp_field = ModelField.infer( + name=rel_name, + value=rel_info, + annotation=ann, + class_validators=None, + config=BaseConfig, + ) + relationship_to = temp_field.type_ + if isinstance(temp_field.type_, ForwardRef): + relationship_to = temp_field.type_.__forward_arg__ + rel_kwargs: Dict[str, Any] = {} + if rel_info.back_populates: + rel_kwargs["back_populates"] = rel_info.back_populates + if rel_info.link_model: + ins = inspect(rel_info.link_model) + local_table = getattr(ins, "local_table") + if local_table is None: + raise RuntimeError( + "Couldn't find the secondary table for " + f"model {rel_info.link_model}" + ) + rel_kwargs["secondary"] = local_table + rel_args: List[Any] = [] + if rel_info.sa_relationship_args: + rel_args.extend(rel_info.sa_relationship_args) + if rel_info.sa_relationship_kwargs: + rel_kwargs.update(rel_info.sa_relationship_kwargs) + rel_value: RelationshipProperty = relationship( + relationship_to, *rel_args, **rel_kwargs + ) + dict_used[rel_name] = rel_value + DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw) + else: + ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) + + +def get_sqlachemy_type(field: ModelField) -> Any: + if issubclass(field.type_, str): + if field.field_info.max_length: + return AutoString(length=field.field_info.max_length) + return AutoString + if issubclass(field.type_, float): + return Float + if issubclass(field.type_, bool): + return Boolean + if issubclass(field.type_, int): + return Integer + if issubclass(field.type_, datetime): + return DateTime + if issubclass(field.type_, date): + return Date + if issubclass(field.type_, timedelta): + return Interval + if issubclass(field.type_, time): + return Time + if issubclass(field.type_, Enum): + return Enum + if issubclass(field.type_, bytes): + return LargeBinary + if issubclass(field.type_, Decimal): + return Numeric + if issubclass(field.type_, ipaddress.IPv4Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv4Network): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Network): + return AutoString + if issubclass(field.type_, Path): + return AutoString + if issubclass(field.type_, uuid.UUID): + return GUID + + +def get_column_from_field(field: ModelField) -> Column: + sa_column = getattr(field.field_info, "sa_column", Undefined) + if isinstance(sa_column, Column): + return sa_column + sa_type = get_sqlachemy_type(field) + primary_key = getattr(field.field_info, "primary_key", False) + nullable = not field.required + index = getattr(field.field_info, "index", Undefined) + if index is Undefined: + index = True + if hasattr(field.field_info, "nullable"): + field_nullable = getattr(field.field_info, "nullable") + if field_nullable != Undefined: + nullable = field_nullable + args = [] + foreign_key = getattr(field.field_info, "foreign_key", None) + if foreign_key: + args.append(ForeignKey(foreign_key)) + kwargs = { + "primary_key": primary_key, + "nullable": nullable, + "index": index, + } + sa_default = Undefined + if field.field_info.default_factory: + sa_default = field.field_info.default_factory + elif field.field_info.default is not Undefined: + sa_default = field.field_info.default + if sa_default is not Undefined: + kwargs["default"] = sa_default + sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) + if sa_column_args is not Undefined: + args.extend(list(cast(Sequence, sa_column_args))) + sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) + if sa_column_kwargs is not Undefined: + kwargs.update(cast(dict, sa_column_kwargs)) + return Column(sa_type, *args, **kwargs) + + +class_registry = weakref.WeakValueDictionary() # type: ignore + +default_registry = registry() + + +class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): + # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values + __slots__ = ("__weakref__",) + __tablename__: ClassVar[Union[str, Callable[..., str]]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] + __name__: ClassVar[str] + metadata: ClassVar[MetaData] + + class Config: + orm_mode = True + + def __new__(cls, *args, **kwargs) -> Any: + new_object = super().__new__(cls) + # SQLAlchemy doesn't call __init__ on the base class + # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html + # Set __fields_set__ here, that would have been set when calling __init__ + # in the Pydantic model so that when SQLAlchemy sets attributes that are + # added (e.g. when querying from DB) to the __fields_set__, this already exists + object.__setattr__(new_object, "__fields_set__", set()) + return new_object + + def __init__(__pydantic_self__, **data: Any) -> None: + # Uses something other than `self` the first arg to allow "self" as a + # settable attribute + if TYPE_CHECKING: + __pydantic_self__.__dict__: Dict[str, Any] = {} + __pydantic_self__.__fields_set__: Set[str] = set() + values, fields_set, validation_error = validate_model( + __pydantic_self__.__class__, data + ) + # Only raise errors if not a SQLModel model + if ( + not getattr(__pydantic_self__.__config__, "table", False) + and validation_error + ): + raise validation_error + # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy + # can handle them + # object.__setattr__(__pydantic_self__, '__dict__', values) + object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) + for key, value in values.items(): + setattr(__pydantic_self__, key, value) + non_pydantic_keys = data.keys() - values.keys() + for key in non_pydantic_keys: + if key in __pydantic_self__.__sqlmodel_relationships__: + setattr(__pydantic_self__, key, data[key]) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"_sa_instance_state"}: + self.__dict__[name] = value + return + else: + # Set in SQLAlchemy, before Pydantic to trigger events and updates + if getattr(self.__config__, "table", False): + if is_instrumented(self, name): + set_attribute(self, name, value) + # Set in Pydantic model to trigger possible validation changes, only for + # non relationship values + if name not in self.__sqlmodel_relationships__: + super().__setattr__(name, value) + + @classmethod + def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None): + # Duplicated from Pydantic + if not cls.__config__.orm_mode: + raise ConfigError( + "You must have the config attribute orm_mode=True to use from_orm" + ) + obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + if not getattr(cls.__config__, "table", False): + # If not table, normal Pydantic code + m = cls.__new__(cls) + else: + # If table, create the new instance normally to make SQLAlchemy create + # the _sa_instance_state attribute + m = cls() + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + # Updated to trigger SQLAlchemy internal handling + if not getattr(cls.__config__, "table", False): + object.__setattr__(m, "__dict__", values) + else: + for key, value in values.items(): + setattr(m, key, value) + # Continue with standard Pydantic logic + object.__setattr__(m, "__fields_set__", fields_set) + m._init_private_attributes() + return m + + @classmethod + def parse_obj( + cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None + ) -> "SQLModel": + obj = cls._enforce_dict_if_root(obj) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + return super().parse_obj(obj) + + def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: + # Don't show SQLAlchemy private attributes + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] + + # From Pydantic, override to enforce validation with dict + @classmethod + def validate(cls: Type["SQLModel"], value: Any) -> "SQLModel": + if isinstance(value, cls): + return value.copy() if cls.__config__.copy_on_model_validation else value + + value = cls._enforce_dict_if_root(value) + if isinstance(value, dict): + values, fields_set, validation_error = validate_model(cls, value) + if validation_error: + raise validation_error + model = cls(**values) + # Reset fields set, this would have been done in Pydantic in __init__ + object.__setattr__(model, "__fields_set__", fields_set) + return model + elif cls.__config__.orm_mode: + return cls.from_orm(value) + elif cls.__custom_root_type__: + return cls.parse_obj(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and exclude_unset is False: + # Original in Pydantic: + # return None + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() + else: + # Original in Pydantic: + # keys = self.__dict__.keys() + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() + + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)} + + return keys + + @declared_attr # type: ignore + def __tablename__(cls) -> str: + return cls.__name__.lower() diff --git a/sqlmodel/orm/__init__.py b/sqlmodel/orm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py new file mode 100644 index 0000000..a96544e --- /dev/null +++ b/sqlmodel/orm/session.py @@ -0,0 +1,135 @@ +from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload + +from sqlalchemy import util +from sqlalchemy.orm import Query as _Query +from sqlalchemy.orm import Session as _Session +from sqlalchemy.sql.base import Executable as _Executable +from sqlmodel.sql.expression import Select, SelectOfScalar +from typing_extensions import Literal + +from ..engine.result import Result, ScalarResult +from ..sql.base import Executable + +_T = TypeVar("_T") + + +class Session(_Session): + @overload + def exec( + self, + statement: Select[_T], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Union[Result[_T]]: + ... + + @overload + def exec( + self, + statement: SelectOfScalar[_T], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Union[ScalarResult[_T]]: + ... + + def exec( + self, + statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Union[Result[_T], ScalarResult[_T]]: + results = super().execute( + statement, + params=params, + execution_options=execution_options, # type: ignore + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + **kw, + ) + if isinstance(statement, SelectOfScalar): + return results.scalars() # type: ignore + return results # type: ignore + + def execute( + self, + statement: _Executable, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Result[Any]: + """ + 🚨 You probably want to use `session.exec()` instead of `session.execute()`. + + This is the original SQLAlchemy `session.execute()` method that returns objects + of type `Row`, and that you have to call `scalars()` to get the model objects. + + For example: + + ```Python + heroes = session.execute(select(Hero)).scalars().all() + ``` + + instead you could use `exec()`: + + ```Python + heroes = session.exec(select(Hero)).all() + ``` + """ + return super().execute( # type: ignore + statement, + params=params, + execution_options=execution_options, # type: ignore + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + **kw, + ) + + def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": + """ + 🚨 You probably want to use `session.exec()` instead of `session.query()`. + + `session.exec()` is SQLModel's own short version with increased type + annotations. + + Or otherwise you might want to use `session.execute()` instead of + `session.query()`. + """ + return super().query(*entities, **kwargs) + + def get( + self, + entity: _T, + ident: Any, + options: Optional[Sequence[Any]] = None, + populate_existing: bool = False, + with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, + identity_token: Optional[Any] = None, + ) -> _T: + return super().get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + ) diff --git a/sqlmodel/pool/__init__.py b/sqlmodel/pool/__init__.py new file mode 100644 index 0000000..20bb952 --- /dev/null +++ b/sqlmodel/pool/__init__.py @@ -0,0 +1 @@ +from sqlalchemy.pool import StaticPool as StaticPool # noqa: F401 diff --git a/sqlmodel/sql/__init__.py b/sqlmodel/sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sqlmodel/sql/base.py b/sqlmodel/sql/base.py new file mode 100644 index 0000000..129e4d4 --- /dev/null +++ b/sqlmodel/sql/base.py @@ -0,0 +1,11 @@ +from typing import Generic, TypeVar + +from sqlalchemy.sql.base import Executable as _Executable + +_T = TypeVar("_T") + + +class Executable(_Executable, Generic[_T]): + def __init__(self, *args, **kwargs): + self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None) + super(_Executable, self).__init__(*args, **kwargs) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py new file mode 100644 index 0000000..e8a922e --- /dev/null +++ b/sqlmodel/sql/expression.py @@ -0,0 +1,459 @@ +# WARNING: do not modify this code, it is generated by expression.py.jinja2 + +import sys +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Mapping, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from uuid import UUID + +from sqlalchemy import Column +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.expression import Select as _Select + +_TSelect = TypeVar("_TSelect") + +# Workaround Generics incompatibility in Python 3.6 +# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322 +if sys.version_info.minor >= 7: + + class Select(_Select, Generic[_TSelect]): + pass + + # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different + # purpose. This is the same as a normal SQLAlchemy Select class where there's only one + # entity, so the result will be converted to a scalar by default. This way writing + # for loops on the results will feel natural. + class SelectOfScalar(_Select, Generic[_TSelect]): + pass + + +else: + from typing import GenericMeta # type: ignore + + class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore + pass + + class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): + pass + + class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): + pass + + # Cast them for editors to work correctly, from several tricks tried, this works + # for both VS Code and PyCharm + Select = cast("Select", _Py36Select) # type: ignore + SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore + + +if TYPE_CHECKING: # pragma: no cover + from ..main import SQLModel + +# Generated TypeVars start + + +_TScalar_0 = TypeVar( + "_TScalar_0", + Column, + Sequence, + Mapping, + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_0 = TypeVar("_TModel_0", bound="SQLModel") + + +_TScalar_1 = TypeVar( + "_TScalar_1", + Column, + Sequence, + Mapping, + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_1 = TypeVar("_TModel_1", bound="SQLModel") + + +_TScalar_2 = TypeVar( + "_TScalar_2", + Column, + Sequence, + Mapping, + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_2 = TypeVar("_TModel_2", bound="SQLModel") + + +_TScalar_3 = TypeVar( + "_TScalar_3", + Column, + Sequence, + Mapping, + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_3 = TypeVar("_TModel_3", bound="SQLModel") + + +# Generated TypeVars end + + +@overload +def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore + ... + + +@overload +def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore + ... + + +# Generated overloads start + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]: + ... + + +# Generated overloads end + + +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: + if len(entities) == 1: + return SelectOfScalar._create(*entities, **kw) # type: ignore + return Select._create(*entities, **kw) # type: ignore + + +# TODO: add several @overload from Python types to SQLAlchemy equivalents +def col(column_expression: Any) -> ColumnClause: + if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): + raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") + return column_expression diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 new file mode 100644 index 0000000..b39d636 --- /dev/null +++ b/sqlmodel/sql/expression.py.jinja2 @@ -0,0 +1,119 @@ +import sys +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Mapping, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from uuid import UUID + +from sqlalchemy import Column +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.expression import Select as _Select + +_TSelect = TypeVar("_TSelect") + +# Workaround Generics incompatibility in Python 3.6 +# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322 +if sys.version_info.minor >= 7: + + class Select(_Select, Generic[_TSelect]): + pass + + # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different + # purpose. This is the same as a normal SQLAlchemy Select class where there's only one + # entity, so the result will be converted to a scalar by default. This way writing + # for loops on the results will feel natural. + class SelectOfScalar(_Select, Generic[_TSelect]): + pass + + +else: + from typing import GenericMeta # type: ignore + + class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore + pass + + class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): + pass + + class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): + pass + + # Cast them for editors to work correctly, from several tricks tried, this works + # for both VS Code and PyCharm + Select = cast("Select", _Py36Select) # type: ignore + SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore + + +if TYPE_CHECKING: # pragma: no cover + from ..main import SQLModel + +# Generated TypeVars start + +{% for i in range(number_of_types) %} +_TScalar_{{ i }} = TypeVar( + "_TScalar_{{ i }}", + Column, + Sequence, + Mapping, + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel") + +{% endfor %} + +# Generated TypeVars end + +@overload +def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore + ... + + +@overload +def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore + ... + + +# Generated overloads start + +{% for signature in signatures %} + +@overload +def select( # type: ignore + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any, + ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: + ... + +{% endfor %} + +# Generated overloads end + +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: + if len(entities) == 1: + return SelectOfScalar._create(*entities, **kw) # type: ignore + return Select._create(*entities, **kw) + + +# TODO: add several @overload from Python types to SQLAlchemy equivalents +def col(column_expression: Any) -> ColumnClause: + if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): + raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") + return column_expression diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py new file mode 100644 index 0000000..e7b77b8 --- /dev/null +++ b/sqlmodel/sql/sqltypes.py @@ -0,0 +1,60 @@ +import uuid +from typing import Any, cast + +from sqlalchemy import types +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.types import CHAR, TypeDecorator + + +class AutoString(types.TypeDecorator): + + impl = types.String + cache_ok = True + mysql_default_length = 255 + + def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": + impl = cast(types.String, self.impl) + if impl.length is None and dialect.name == "mysql": + return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore + return super().load_dialect_impl(dialect) + + +# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type +# with small modifications +class GUID(TypeDecorator): + """Platform-independent GUID type. + + Uses PostgreSQL's UUID type, otherwise uses + CHAR(32), storing as stringified hex values. + + """ + + impl = CHAR + cache_ok = True + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(32)) + + def process_bind_param(self, value, dialect): + if value is None: + return value + elif dialect.name == "postgresql": + return str(value) + else: + if not isinstance(value, uuid.UUID): + return f"{uuid.UUID(value).int:x}" + else: + # hexstring + return f"{value.int:x}" + + def process_result_value(self, value, dialect): + if value is None: + return value + else: + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + return value