Add SQLModel core code

This commit is contained in:
Sebastián Ramírez 2021-08-24 14:41:53 +02:00
commit fcff2050e6
17 changed files with 1867 additions and 0 deletions

139
sqlmodel/__init__.py Normal file
View File

@ -0,0 +1,139 @@
__version__ = "0.0.1"
# Re-export from SQLAlchemy
from sqlalchemy.engine import create_mock_engine as create_mock_engine
from sqlalchemy.engine import engine_from_config as engine_from_config
from sqlalchemy.inspection import inspect as inspect
from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
from sqlalchemy.schema import CheckConstraint as CheckConstraint
from sqlalchemy.schema import Column as Column
from sqlalchemy.schema import ColumnDefault as ColumnDefault
from sqlalchemy.schema import Computed as Computed
from sqlalchemy.schema import Constraint as Constraint
from sqlalchemy.schema import DDL as DDL
from sqlalchemy.schema import DefaultClause as DefaultClause
from sqlalchemy.schema import FetchedValue as FetchedValue
from sqlalchemy.schema import ForeignKey as ForeignKey
from sqlalchemy.schema import ForeignKeyConstraint as ForeignKeyConstraint
from sqlalchemy.schema import Identity as Identity
from sqlalchemy.schema import Index as Index
from sqlalchemy.schema import MetaData as MetaData
from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
from sqlalchemy.schema import Sequence as Sequence
from sqlalchemy.schema import Table as Table
from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
from sqlalchemy.sql import alias as alias
from sqlalchemy.sql import all_ as all_
from sqlalchemy.sql import and_ as and_
from sqlalchemy.sql import any_ as any_
from sqlalchemy.sql import asc as asc
from sqlalchemy.sql import between as between
from sqlalchemy.sql import bindparam as bindparam
from sqlalchemy.sql import case as case
from sqlalchemy.sql import cast as cast
from sqlalchemy.sql import collate as collate
from sqlalchemy.sql import column as column
from sqlalchemy.sql import delete as delete
from sqlalchemy.sql import desc as desc
from sqlalchemy.sql import distinct as distinct
from sqlalchemy.sql import except_ as except_
from sqlalchemy.sql import except_all as except_all
from sqlalchemy.sql import exists as exists
from sqlalchemy.sql import extract as extract
from sqlalchemy.sql import false as false
from sqlalchemy.sql import func as func
from sqlalchemy.sql import funcfilter as funcfilter
from sqlalchemy.sql import insert as insert
from sqlalchemy.sql import intersect as intersect
from sqlalchemy.sql import intersect_all as intersect_all
from sqlalchemy.sql import join as join
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from sqlalchemy.sql import (
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
)
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from sqlalchemy.sql import lambda_stmt as lambda_stmt
from sqlalchemy.sql import lateral as lateral
from sqlalchemy.sql import literal as literal
from sqlalchemy.sql import literal_column as literal_column
from sqlalchemy.sql import modifier as modifier
from sqlalchemy.sql import not_ as not_
from sqlalchemy.sql import null as null
from sqlalchemy.sql import nulls_first as nulls_first
from sqlalchemy.sql import nulls_last as nulls_last
from sqlalchemy.sql import nullsfirst as nullsfirst
from sqlalchemy.sql import nullslast as nullslast
from sqlalchemy.sql import or_ as or_
from sqlalchemy.sql import outerjoin as outerjoin
from sqlalchemy.sql import outparam as outparam
from sqlalchemy.sql import over as over
from sqlalchemy.sql import subquery as subquery
from sqlalchemy.sql import table as table
from sqlalchemy.sql import tablesample as tablesample
from sqlalchemy.sql import text as text
from sqlalchemy.sql import true as true
from sqlalchemy.sql import tuple_ as tuple_
from sqlalchemy.sql import type_coerce as type_coerce
from sqlalchemy.sql import union as union
from sqlalchemy.sql import union_all as union_all
from sqlalchemy.sql import update as update
from sqlalchemy.sql import values as values
from sqlalchemy.sql import within_group as within_group
from sqlalchemy.types import ARRAY as ARRAY
from sqlalchemy.types import BIGINT as BIGINT
from sqlalchemy.types import BigInteger as BigInteger
from sqlalchemy.types import BINARY as BINARY
from sqlalchemy.types import BLOB as BLOB
from sqlalchemy.types import BOOLEAN as BOOLEAN
from sqlalchemy.types import Boolean as Boolean
from sqlalchemy.types import CHAR as CHAR
from sqlalchemy.types import CLOB as CLOB
from sqlalchemy.types import DATE as DATE
from sqlalchemy.types import Date as Date
from sqlalchemy.types import DATETIME as DATETIME
from sqlalchemy.types import DateTime as DateTime
from sqlalchemy.types import DECIMAL as DECIMAL
from sqlalchemy.types import Enum as Enum
from sqlalchemy.types import FLOAT as FLOAT
from sqlalchemy.types import Float as Float
from sqlalchemy.types import INT as INT
from sqlalchemy.types import INTEGER as INTEGER
from sqlalchemy.types import Integer as Integer
from sqlalchemy.types import Interval as Interval
from sqlalchemy.types import JSON as JSON
from sqlalchemy.types import LargeBinary as LargeBinary
from sqlalchemy.types import NCHAR as NCHAR
from sqlalchemy.types import NUMERIC as NUMERIC
from sqlalchemy.types import Numeric as Numeric
from sqlalchemy.types import NVARCHAR as NVARCHAR
from sqlalchemy.types import PickleType as PickleType
from sqlalchemy.types import REAL as REAL
from sqlalchemy.types import SMALLINT as SMALLINT
from sqlalchemy.types import SmallInteger as SmallInteger
from sqlalchemy.types import String as String
from sqlalchemy.types import TEXT as TEXT
from sqlalchemy.types import Text as Text
from sqlalchemy.types import TIME as TIME
from sqlalchemy.types import Time as Time
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
from sqlalchemy.types import TypeDecorator as TypeDecorator
from sqlalchemy.types import Unicode as Unicode
from sqlalchemy.types import UnicodeText as UnicodeText
from sqlalchemy.types import VARBINARY as VARBINARY
from sqlalchemy.types import VARCHAR as VARCHAR
# Extensions and modifications of SQLAlchemy in SQLModel
from .engine.create import create_engine as create_engine
from .orm.session import Session as Session
from .sql.expression import select as select
from .sql.expression import col as col
from .sql.sqltypes import AutoString as AutoString
# Export SQLModel specifics (equivalent to Pydantic)
from .main import SQLModel as SQLModel
from .main import Field as Field
from .main import Relationship as Relationship

32
sqlmodel/default.py Normal file
View File

@ -0,0 +1,32 @@
from typing import Any, TypeVar
class _DefaultPlaceholder:
"""
You shouldn't use this class directly.
It's used internally to recognize when a default value has been overwritten, even
if the overriden default value was truthy.
"""
def __init__(self, value: Any):
self.value = value
def __bool__(self) -> bool:
return bool(self.value)
def __eq__(self, o: object) -> bool:
return isinstance(o, _DefaultPlaceholder) and o.value == self.value
_TDefaultType = TypeVar("_TDefaultType")
def Default(value: _TDefaultType) -> _TDefaultType:
"""
You shouldn't use this function directly.
It's used internally to recognize when a default value has been overwritten, even
if the overriden default value was truthy.
"""
return _DefaultPlaceholder(value) # type: ignore

View File

139
sqlmodel/engine/create.py Normal file
View File

@ -0,0 +1,139 @@
import json
import sqlite3
from typing import Any, Callable, Dict, List, Optional, Type, Union
from sqlalchemy import create_engine as _create_engine
from sqlalchemy.engine.url import URL
from sqlalchemy.future import Engine as _FutureEngine
from sqlalchemy.pool import Pool
from typing_extensions import Literal, TypedDict
from ..default import Default, _DefaultPlaceholder
# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here
_Debug = Literal["debug"]
_IsolationLevel = Literal[
"SERIALIZABLE",
"REPEATABLE READ",
"READ COMMITTED",
"READ UNCOMMITTED",
"AUTOCOMMIT",
]
_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
_ResetOnReturn = Literal["rollback", "commit"]
class _SQLiteConnectArgs(TypedDict, total=False):
timeout: float
detect_types: Any
isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
check_same_thread: bool
factory: Type[sqlite3.Connection]
cached_statements: int
uri: bool
_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
# Re-define create_engine to have by default future=True, and assume that's what is used
# Also show the default values used for each parameter, but don't set them unless
# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't
# support pool connection arguments.
def create_engine(
url: Union[str, URL],
*,
connect_args: _ConnectArgs = Default({}), # type: ignore
echo: Union[bool, _Debug] = Default(False),
echo_pool: Union[bool, _Debug] = Default(False),
enable_from_linting: bool = Default(True),
encoding: str = Default("utf-8"),
execution_options: Dict[Any, Any] = Default({}),
future: bool = True,
hide_parameters: bool = Default(False),
implicit_returning: bool = Default(True),
isolation_level: Optional[_IsolationLevel] = Default(None),
json_deserializer: Callable[..., Any] = Default(json.loads),
json_serializer: Callable[..., Any] = Default(json.dumps),
label_length: Optional[int] = Default(None),
logging_name: Optional[str] = Default(None),
max_identifier_length: Optional[int] = Default(None),
max_overflow: int = Default(10),
module: Optional[Any] = Default(None),
paramstyle: Optional[_ParamStyle] = Default(None),
pool: Optional[Pool] = Default(None),
poolclass: Optional[Type[Pool]] = Default(None),
pool_logging_name: Optional[str] = Default(None),
pool_pre_ping: bool = Default(False),
pool_size: int = Default(5),
pool_recycle: int = Default(-1),
pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"),
pool_timeout: float = Default(30),
pool_use_lifo: bool = Default(False),
plugins: Optional[List[str]] = Default(None),
query_cache_size: Optional[int] = Default(None),
**kwargs: Any,
) -> _FutureEngine:
current_kwargs: Dict[str, Any] = {
"future": future,
}
if not isinstance(echo, _DefaultPlaceholder):
current_kwargs["echo"] = echo
if not isinstance(echo_pool, _DefaultPlaceholder):
current_kwargs["echo_pool"] = echo_pool
if not isinstance(enable_from_linting, _DefaultPlaceholder):
current_kwargs["enable_from_linting"] = enable_from_linting
if not isinstance(connect_args, _DefaultPlaceholder):
current_kwargs["connect_args"] = connect_args
if not isinstance(encoding, _DefaultPlaceholder):
current_kwargs["encoding"] = encoding
if not isinstance(execution_options, _DefaultPlaceholder):
current_kwargs["execution_options"] = execution_options
if not isinstance(hide_parameters, _DefaultPlaceholder):
current_kwargs["hide_parameters"] = hide_parameters
if not isinstance(implicit_returning, _DefaultPlaceholder):
current_kwargs["implicit_returning"] = implicit_returning
if not isinstance(isolation_level, _DefaultPlaceholder):
current_kwargs["isolation_level"] = isolation_level
if not isinstance(json_deserializer, _DefaultPlaceholder):
current_kwargs["json_deserializer"] = json_deserializer
if not isinstance(json_serializer, _DefaultPlaceholder):
current_kwargs["json_serializer"] = json_serializer
if not isinstance(label_length, _DefaultPlaceholder):
current_kwargs["label_length"] = label_length
if not isinstance(logging_name, _DefaultPlaceholder):
current_kwargs["logging_name"] = logging_name
if not isinstance(max_identifier_length, _DefaultPlaceholder):
current_kwargs["max_identifier_length"] = max_identifier_length
if not isinstance(max_overflow, _DefaultPlaceholder):
current_kwargs["max_overflow"] = max_overflow
if not isinstance(module, _DefaultPlaceholder):
current_kwargs["module"] = module
if not isinstance(paramstyle, _DefaultPlaceholder):
current_kwargs["paramstyle"] = paramstyle
if not isinstance(pool, _DefaultPlaceholder):
current_kwargs["pool"] = pool
if not isinstance(poolclass, _DefaultPlaceholder):
current_kwargs["poolclass"] = poolclass
if not isinstance(pool_logging_name, _DefaultPlaceholder):
current_kwargs["pool_logging_name"] = pool_logging_name
if not isinstance(pool_pre_ping, _DefaultPlaceholder):
current_kwargs["pool_pre_ping"] = pool_pre_ping
if not isinstance(pool_size, _DefaultPlaceholder):
current_kwargs["pool_size"] = pool_size
if not isinstance(pool_recycle, _DefaultPlaceholder):
current_kwargs["pool_recycle"] = pool_recycle
if not isinstance(pool_reset_on_return, _DefaultPlaceholder):
current_kwargs["pool_reset_on_return"] = pool_reset_on_return
if not isinstance(pool_timeout, _DefaultPlaceholder):
current_kwargs["pool_timeout"] = pool_timeout
if not isinstance(pool_use_lifo, _DefaultPlaceholder):
current_kwargs["pool_use_lifo"] = pool_use_lifo
if not isinstance(plugins, _DefaultPlaceholder):
current_kwargs["plugins"] = plugins
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs)

79
sqlmodel/engine/result.py Normal file
View File

@ -0,0 +1,79 @@
from typing import Generic, Iterator, List, Optional, TypeVar
from sqlalchemy.engine.result import Result as _Result
from sqlalchemy.engine.result import ScalarResult as _ScalarResult
_T = TypeVar("_T")
class ScalarResult(_ScalarResult, Generic[_T]):
def all(self) -> List[_T]:
return super().all()
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
return super().partitions(size)
def fetchall(self) -> List[_T]:
return super().fetchall()
def fetchmany(self, size: Optional[int] = None) -> List[_T]:
return super().fetchmany(size)
def __iter__(self) -> Iterator[_T]:
return super().__iter__()
def __next__(self) -> _T:
return super().__next__()
def first(self) -> Optional[_T]:
return super().first()
def one_or_none(self) -> Optional[_T]:
return super().one_or_none()
def one(self) -> _T:
return super().one()
class Result(_Result, Generic[_T]):
def scalars(self, index: int = 0) -> ScalarResult[_T]:
return super().scalars(index) # type: ignore
def __iter__(self) -> Iterator[_T]: # type: ignore
return super().__iter__() # type: ignore
def __next__(self) -> _T: # type: ignore
return super().__next__() # type: ignore
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore
return super().partitions(size) # type: ignore
def fetchall(self) -> List[_T]: # type: ignore
return super().fetchall() # type: ignore
def fetchone(self) -> Optional[_T]: # type: ignore
return super().fetchone() # type: ignore
def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore
return super().fetchmany() # type: ignore
def all(self) -> List[_T]: # type: ignore
return super().all() # type: ignore
def first(self) -> Optional[_T]: # type: ignore
return super().first() # type: ignore
def one_or_none(self) -> Optional[_T]: # type: ignore
return super().one_or_none() # type: ignore
def scalar_one(self) -> _T:
return super().scalar_one() # type: ignore
def scalar_one_or_none(self) -> Optional[_T]:
return super().scalar_one_or_none() # type: ignore
def one(self) -> _T: # type: ignore
return super().one() # type: ignore
def scalar(self) -> Optional[_T]:
return super().scalar() # type: ignore

0
sqlmodel/ext/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,62 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable
from ...engine.result import ScalarResult
from ...orm.session import Session
from ...sql.expression import Select
_T = TypeVar("_T")
class AsyncSession(_AsyncSession):
sync_session: Session
def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw,
):
# All the same code of the original AsyncSession
kw["future"] = True
if bind:
self.bind = bind
bind = engine._get_sync_engine_or_connection(bind) # type: ignore
if binds:
self.binds = binds
binds = {
key: engine._get_sync_engine_or_connection(b) # type: ignore
for key, b in binds.items()
}
self.sync_session = self._proxied = self._assign_proxied( # type: ignore
Session(bind=bind, binds=binds, **kw) # type: ignore
)
async def exec(
self,
statement: Union[Select[_T], Executable[_T]],
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_T]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
return await greenlet_spawn( # type: ignore
self.sync_session.exec,
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)

631
sqlmodel/main.py Normal file
View File

@ -0,0 +1,631 @@
import ipaddress
import uuid
import weakref
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
ForwardRef,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from pydantic import BaseModel
from pydantic.errors import ConfigError, DictError
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import BaseConfig, ModelMetaclass, validate_model
from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.utils import ROOT_KEY, Representation, ValueItems
from sqlalchemy import (
Boolean,
Column,
Date,
DateTime,
Float,
ForeignKey,
Integer,
Interval,
Numeric,
inspect,
)
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time
from .sql.sqltypes import GUID, AutoString
_T = TypeVar("_T")
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
return lambda a: a
class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", Undefined)
foreign_key = kwargs.pop("foreign_key", Undefined)
index = kwargs.pop("index", Undefined)
sa_column = kwargs.pop("sa_column", Undefined)
sa_column_args = kwargs.pop("sa_column_args", Undefined)
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
if sa_column is not Undefined:
if sa_column_args is not Undefined:
raise RuntimeError(
"Passing sa_column_args is not supported when "
"also passing a sa_column"
)
if sa_column_kwargs is not Undefined:
raise RuntimeError(
"Passing sa_column_kwargs is not supported when "
"also passing a sa_column"
)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.nullable = nullable
self.foreign_key = foreign_key
self.index = index
self.sa_column = sa_column
self.sa_column_args = sa_column_args
self.sa_column_kwargs = sa_column_kwargs
class RelationshipInfo(Representation):
def __init__(
self,
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> None:
if sa_relationship is not None:
if sa_relationship_args is not None:
raise RuntimeError(
"Passing sa_relationship_args is not supported when "
"also passing a sa_relationship"
)
if sa_relationship_kwargs is not None:
raise RuntimeError(
"Passing sa_relationship_kwargs is not supported when "
"also passing a sa_relationship"
)
self.back_populates = back_populates
self.link_model = link_model
self.sa_relationship = sa_relationship
self.sa_relationship_args = sa_relationship_args
self.sa_relationship_kwargs = sa_relationship_kwargs
def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: str = None,
title: str = None,
description: str = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: bool = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
multiple_of: float = None,
min_items: int = None,
max_items: int = None,
min_length: int = None,
max_length: int = None,
allow_mutation: bool = True,
regex: str = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined,
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
field_info = FieldInfo(
default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
exclude=exclude,
include=include,
const=const,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
min_items=min_items,
max_items=max_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
primary_key=primary_key,
foreign_key=foreign_key,
nullable=nullable,
index=index,
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
**current_schema_extra,
)
field_info._validate()
return field_info
def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
relationship_info = RelationshipInfo(
back_populates=back_populates,
link_model=link_model,
sa_relationship=sa_relationship,
sa_relationship_args=sa_relationship_args,
sa_relationship_kwargs=sa_relationship_kwargs,
)
return relationship_info
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
__config__: Type[BaseConfig]
__fields__: Dict[str, ModelField]
# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)
def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
DeclarativeMeta.__delattr__(cls, name)
else:
super().__delattr__(name)
# From Pydantic
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {}
original_annotations = resolve_annotations(
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
)
pydantic_annotations = {}
relationship_annotations = {}
for k, v in class_dict.items():
if isinstance(v, RelationshipInfo):
relationships[k] = v
else:
dict_for_pydantic[k] = v
for k, v in original_annotations.items():
if k in relationships:
relationship_annotations[k] = v
else:
pydantic_annotations[k] = v
dict_used = {
**dict_for_pydantic,
"__weakref__": None,
"__sqlmodel_relationships__": relationships,
"__annotations__": pydantic_annotations,
}
# Duplicate logic from Pydantic to filter config kwargs because if they are
# passed directly including the registry Pydantic will pass them over to the
# superclass causing an error
allowed_config_kwargs: Set[str] = {
key
for key in dir(BaseConfig)
if not (
key.startswith("__") and key.endswith("__")
) # skip dunder methods and attributes
}
pydantic_kwargs = kwargs.copy()
config_kwargs = {
key: pydantic_kwargs.pop(key)
for key in pydantic_kwargs.keys() & allowed_config_kwargs
}
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
new_cls.__annotations__ = {
**relationship_annotations,
**pydantic_annotations,
**new_cls.__annotations__,
}
def get_config(name: str) -> Any:
config_class_value = getattr(new_cls.__config__, name, Undefined)
if config_class_value is not Undefined:
return config_class_value
kwarg_value = kwargs.get(name, Undefined)
if kwarg_value is not Undefined:
return kwarg_value
return Undefined
config_table = get_config("table")
if config_table is True:
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.table = config_table
for k, v in new_cls.__fields__.items():
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
# This could be done by reading new_cls.__config__.table in FastAPI, but
# that's very specific about SQLModel, so let's have another config that
# other future tools based on Pydantic can use.
new_cls.__config__.read_with_orm_mode = True
config_registry = get_config("registry")
if config_registry is not Undefined:
config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.registry = config_table
setattr(new_cls, "_sa_registry", config_registry)
setattr(new_cls, "metadata", config_registry.metadata)
setattr(new_cls, "__abstract__", True)
return new_cls
# Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
def __init__(
cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
) -> None:
# Only one of the base classes (or the current one) should be a table model
# this allows FastAPI cloning a SQLModel for the response_model without
# trying to create a new SQLAlchemy, for a new table, with the same name, that
# triggers an error
base_is_table = False
for base in bases:
config = getattr(base, "__config__")
if config and getattr(config, "table", False):
base_is_table = True
break
if getattr(cls.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
dict_used[field_name] = get_column_from_field(field_value)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
# over anything else, use that and continue with the next attribute
dict_used[rel_name] = rel_info.sa_relationship
continue
ann = cls.__annotations__[rel_name]
temp_field = ModelField.infer(
name=rel_name,
value=rel_info,
annotation=ann,
class_validators=None,
config=BaseConfig,
)
relationship_to = temp_field.type_
if isinstance(temp_field.type_, ForwardRef):
relationship_to = temp_field.type_.__forward_arg__
rel_kwargs: Dict[str, Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
if rel_info.link_model:
ins = inspect(rel_info.link_model)
local_table = getattr(ins, "local_table")
if local_table is None:
raise RuntimeError(
"Couldn't find the secondary table for "
f"model {rel_info.link_model}"
)
rel_kwargs["secondary"] = local_table
rel_args: List[Any] = []
if rel_info.sa_relationship_args:
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship(
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
else:
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def get_sqlachemy_type(field: ModelField) -> Any:
if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
return AutoString
if issubclass(field.type_, float):
return Float
if issubclass(field.type_, bool):
return Boolean
if issubclass(field.type_, int):
return Integer
if issubclass(field.type_, datetime):
return DateTime
if issubclass(field.type_, date):
return Date
if issubclass(field.type_, timedelta):
return Interval
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, Enum):
return Enum
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Network):
return AutoString
if issubclass(field.type_, Path):
return AutoString
if issubclass(field.type_, uuid.UUID):
return GUID
def get_column_from_field(field: ModelField) -> Column:
sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlachemy_type(field)
primary_key = getattr(field.field_info, "primary_key", False)
nullable = not field.required
index = getattr(field.field_info, "index", Undefined)
if index is Undefined:
index = True
if hasattr(field.field_info, "nullable"):
field_nullable = getattr(field.field_info, "nullable")
if field_nullable != Undefined:
nullable = field_nullable
args = []
foreign_key = getattr(field.field_info, "foreign_key", None)
if foreign_key:
args.append(ForeignKey(foreign_key))
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
"index": index,
}
sa_default = Undefined
if field.field_info.default_factory:
sa_default = field.field_info.default_factory
elif field.field_info.default is not Undefined:
sa_default = field.field_info.default
if sa_default is not Undefined:
kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence, sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(dict, sa_column_kwargs))
return Column(sa_type, *args, **kwargs)
class_registry = weakref.WeakValueDictionary() # type: ignore
default_registry = registry()
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
class Config:
orm_mode = True
def __new__(cls, *args, **kwargs) -> Any:
new_object = super().__new__(cls)
# SQLAlchemy doesn't call __init__ on the base class
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
# Set __fields_set__ here, that would have been set when calling __init__
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__fields_set__", set())
return new_object
def __init__(__pydantic_self__, **data: Any) -> None:
# Uses something other than `self` the first arg to allow "self" as a
# settable attribute
if TYPE_CHECKING:
__pydantic_self__.__dict__: Dict[str, Any] = {}
__pydantic_self__.__fields_set__: Set[str] = set()
values, fields_set, validation_error = validate_model(
__pydantic_self__.__class__, data
)
# Only raise errors if not a SQLModel model
if (
not getattr(__pydantic_self__.__config__, "table", False)
and validation_error
):
raise validation_error
# Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
# can handle them
# object.__setattr__(__pydantic_self__, '__dict__', values)
object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
for key, value in values.items():
setattr(__pydantic_self__, key, value)
non_pydantic_keys = data.keys() - values.keys()
for key in non_pydantic_keys:
if key in __pydantic_self__.__sqlmodel_relationships__:
setattr(__pydantic_self__, key, data[key])
def __setattr__(self, name: str, value: Any) -> None:
if name in {"_sa_instance_state"}:
self.__dict__[name] = value
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if getattr(self.__config__, "table", False):
if is_instrumented(self, name):
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
if name not in self.__sqlmodel_relationships__:
super().__setattr__(name, value)
@classmethod
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
# Duplicated from Pydantic
if not cls.__config__.orm_mode:
raise ConfigError(
"You must have the config attribute orm_mode=True to use from_orm"
)
obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
# SQLModel, support update dict
if update is not None:
obj = {**obj, **update}
# End SQLModel support dict
if not getattr(cls.__config__, "table", False):
# If not table, normal Pydantic code
m = cls.__new__(cls)
else:
# If table, create the new instance normally to make SQLAlchemy create
# the _sa_instance_state attribute
m = cls()
values, fields_set, validation_error = validate_model(cls, obj)
if validation_error:
raise validation_error
# Updated to trigger SQLAlchemy internal handling
if not getattr(cls.__config__, "table", False):
object.__setattr__(m, "__dict__", values)
else:
for key, value in values.items():
setattr(m, key, value)
# Continue with standard Pydantic logic
object.__setattr__(m, "__fields_set__", fields_set)
m._init_private_attributes()
return m
@classmethod
def parse_obj(
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
) -> "SQLModel":
obj = cls._enforce_dict_if_root(obj)
# SQLModel, support update dict
if update is not None:
obj = {**obj, **update}
# End SQLModel support dict
return super().parse_obj(obj)
def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
# Don't show SQLAlchemy private attributes
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
# From Pydantic, override to enforce validation with dict
@classmethod
def validate(cls: Type["SQLModel"], value: Any) -> "SQLModel":
if isinstance(value, cls):
return value.copy() if cls.__config__.copy_on_model_validation else value
value = cls._enforce_dict_if_root(value)
if isinstance(value, dict):
values, fields_set, validation_error = validate_model(cls, value)
if validation_error:
raise validation_error
model = cls(**values)
# Reset fields set, this would have been done in Pydantic in __init__
object.__setattr__(model, "__fields_set__", fields_set)
return model
elif cls.__config__.orm_mode:
return cls.from_orm(value)
elif cls.__custom_root_type__:
return cls.parse_obj(value)
else:
try:
value_as_dict = dict(value)
except (TypeError, ValueError) as e:
raise DictError() from e
return cls(**value_as_dict)
# From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
def _calculate_keys(
self,
include: Optional[Mapping[Union[int, str], Any]],
exclude: Optional[Mapping[Union[int, str], Any]],
exclude_unset: bool,
update: Optional[Dict[str, Any]] = None,
) -> Optional[AbstractSet[str]]:
if include is None and exclude is None and exclude_unset is False:
# Original in Pydantic:
# return None
# Updated to not return SQLAlchemy attributes
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
keys: AbstractSet[str]
if exclude_unset:
keys = self.__fields_set__.copy()
else:
# Original in Pydantic:
# keys = self.__dict__.keys()
# Updated to not return SQLAlchemy attributes
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
if include is not None:
keys &= include.keys()
if update:
keys -= update.keys()
if exclude:
keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)}
return keys
@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()

0
sqlmodel/orm/__init__.py Normal file
View File

135
sqlmodel/orm/session.py Normal file
View File

@ -0,0 +1,135 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
from sqlalchemy import util
from sqlalchemy.orm import Query as _Query
from sqlalchemy.orm import Session as _Session
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal
from ..engine.result import Result, ScalarResult
from ..sql.base import Executable
_T = TypeVar("_T")
class Session(_Session):
@overload
def exec(
self,
statement: Select[_T],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_T]]:
...
@overload
def exec(
self,
statement: SelectOfScalar[_T],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[ScalarResult[_T]]:
...
def exec(
self,
statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_T], ScalarResult[_T]]:
results = super().execute(
statement,
params=params,
execution_options=execution_options, # type: ignore
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
**kw,
)
if isinstance(statement, SelectOfScalar):
return results.scalars() # type: ignore
return results # type: ignore
def execute(
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[Any]:
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = session.exec(select(Hero)).all()
```
"""
return super().execute( # type: ignore
statement,
params=params,
execution_options=execution_options, # type: ignore
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
**kw,
)
def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
"""
🚨 You probably want to use `session.exec()` instead of `session.query()`.
`session.exec()` is SQLModel's own short version with increased type
annotations.
Or otherwise you might want to use `session.execute()` instead of
`session.query()`.
"""
return super().query(*entities, **kwargs)
def get(
self,
entity: _T,
ident: Any,
options: Optional[Sequence[Any]] = None,
populate_existing: bool = False,
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
identity_token: Optional[Any] = None,
) -> _T:
return super().get(
entity,
ident,
options=options,
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
)

View File

@ -0,0 +1 @@
from sqlalchemy.pool import StaticPool as StaticPool # noqa: F401

0
sqlmodel/sql/__init__.py Normal file
View File

11
sqlmodel/sql/base.py Normal file
View File

@ -0,0 +1,11 @@
from typing import Generic, TypeVar
from sqlalchemy.sql.base import Executable as _Executable
_T = TypeVar("_T")
class Executable(_Executable, Generic[_T]):
def __init__(self, *args, **kwargs):
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
super(_Executable, self).__init__(*args, **kwargs)

459
sqlmodel/sql/expression.py Normal file
View File

@ -0,0 +1,459 @@
# WARNING: do not modify this code, it is generated by expression.py.jinja2
import sys
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Generic,
Mapping,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID
from sqlalchemy import Column
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import Select as _Select
_TSelect = TypeVar("_TSelect")
# Workaround Generics incompatibility in Python 3.6
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
if sys.version_info.minor >= 7:
class Select(_Select, Generic[_TSelect]):
pass
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
pass
else:
from typing import GenericMeta # type: ignore
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
# Cast them for editors to work correctly, from several tricks tried, this works
# for both VS Code and PyCharm
Select = cast("Select", _Py36Select) # type: ignore
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
if TYPE_CHECKING: # pragma: no cover
from ..main import SQLModel
# Generated TypeVars start
_TScalar_0 = TypeVar(
"_TScalar_0",
Column,
Sequence,
Mapping,
UUID,
datetime,
float,
int,
bool,
bytes,
str,
None,
)
_TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
_TScalar_1 = TypeVar(
"_TScalar_1",
Column,
Sequence,
Mapping,
UUID,
datetime,
float,
int,
bool,
bytes,
str,
None,
)
_TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
_TScalar_2 = TypeVar(
"_TScalar_2",
Column,
Sequence,
Mapping,
UUID,
datetime,
float,
int,
bool,
bytes,
str,
None,
)
_TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
_TScalar_3 = TypeVar(
"_TScalar_3",
Column,
Sequence,
Mapping,
UUID,
datetime,
float,
int,
bool,
bytes,
str,
None,
)
_TModel_3 = TypeVar("_TModel_3", bound="SQLModel")
# Generated TypeVars end
@overload
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
...
@overload
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
...
# Generated overloads start
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: _TScalar_1,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: _TScalar_1,
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]:
...
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) # type: ignore
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause:
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression

View File

@ -0,0 +1,119 @@
import sys
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Generic,
Mapping,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID
from sqlalchemy import Column
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import Select as _Select
_TSelect = TypeVar("_TSelect")
# Workaround Generics incompatibility in Python 3.6
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
if sys.version_info.minor >= 7:
class Select(_Select, Generic[_TSelect]):
pass
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
pass
else:
from typing import GenericMeta # type: ignore
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
# Cast them for editors to work correctly, from several tricks tried, this works
# for both VS Code and PyCharm
Select = cast("Select", _Py36Select) # type: ignore
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
if TYPE_CHECKING: # pragma: no cover
from ..main import SQLModel
# Generated TypeVars start
{% for i in range(number_of_types) %}
_TScalar_{{ i }} = TypeVar(
"_TScalar_{{ i }}",
Column,
Sequence,
Mapping,
UUID,
datetime,
float,
int,
bool,
bytes,
str,
None,
)
_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
{% endfor %}
# Generated TypeVars end
@overload
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
...
@overload
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
...
# Generated overloads start
{% for signature in signatures %}
@overload
def select( # type: ignore
{% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
...
{% endfor %}
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw)
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause:
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression

60
sqlmodel/sql/sqltypes.py Normal file
View File

@ -0,0 +1,60 @@
import uuid
from typing import Any, cast
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import CHAR, TypeDecorator
class AutoString(types.TypeDecorator):
impl = types.String
cache_ok = True
mysql_default_length = 255
def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
impl = cast(types.String, self.impl)
if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore
return super().load_dialect_impl(dialect)
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
# with small modifications
class GUID(TypeDecorator):
"""Platform-independent GUID type.
Uses PostgreSQL's UUID type, otherwise uses
CHAR(32), storing as stringified hex values.
"""
impl = CHAR
cache_ok = True
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(32))
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == "postgresql":
return str(value)
else:
if not isinstance(value, uuid.UUID):
return f"{uuid.UUID(value).int:x}"
else:
# hexstring
return f"{value.int:x}"
def process_result_value(self, value, dialect):
if value is None:
return value
else:
if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)
return value