mirror of
https://github.com/PaiGramTeam/sqlmodel.git
synced 2024-11-24 17:19:33 +00:00
✨ Add SQLModel core code
This commit is contained in:
commit
fcff2050e6
139
sqlmodel/__init__.py
Normal file
139
sqlmodel/__init__.py
Normal file
@ -0,0 +1,139 @@
|
||||
__version__ = "0.0.1"
|
||||
|
||||
# Re-export from SQLAlchemy
|
||||
from sqlalchemy.engine import create_mock_engine as create_mock_engine
|
||||
from sqlalchemy.engine import engine_from_config as engine_from_config
|
||||
from sqlalchemy.inspection import inspect as inspect
|
||||
from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
|
||||
from sqlalchemy.schema import CheckConstraint as CheckConstraint
|
||||
from sqlalchemy.schema import Column as Column
|
||||
from sqlalchemy.schema import ColumnDefault as ColumnDefault
|
||||
from sqlalchemy.schema import Computed as Computed
|
||||
from sqlalchemy.schema import Constraint as Constraint
|
||||
from sqlalchemy.schema import DDL as DDL
|
||||
from sqlalchemy.schema import DefaultClause as DefaultClause
|
||||
from sqlalchemy.schema import FetchedValue as FetchedValue
|
||||
from sqlalchemy.schema import ForeignKey as ForeignKey
|
||||
from sqlalchemy.schema import ForeignKeyConstraint as ForeignKeyConstraint
|
||||
from sqlalchemy.schema import Identity as Identity
|
||||
from sqlalchemy.schema import Index as Index
|
||||
from sqlalchemy.schema import MetaData as MetaData
|
||||
from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
|
||||
from sqlalchemy.schema import Sequence as Sequence
|
||||
from sqlalchemy.schema import Table as Table
|
||||
from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
|
||||
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
|
||||
from sqlalchemy.sql import alias as alias
|
||||
from sqlalchemy.sql import all_ as all_
|
||||
from sqlalchemy.sql import and_ as and_
|
||||
from sqlalchemy.sql import any_ as any_
|
||||
from sqlalchemy.sql import asc as asc
|
||||
from sqlalchemy.sql import between as between
|
||||
from sqlalchemy.sql import bindparam as bindparam
|
||||
from sqlalchemy.sql import case as case
|
||||
from sqlalchemy.sql import cast as cast
|
||||
from sqlalchemy.sql import collate as collate
|
||||
from sqlalchemy.sql import column as column
|
||||
from sqlalchemy.sql import delete as delete
|
||||
from sqlalchemy.sql import desc as desc
|
||||
from sqlalchemy.sql import distinct as distinct
|
||||
from sqlalchemy.sql import except_ as except_
|
||||
from sqlalchemy.sql import except_all as except_all
|
||||
from sqlalchemy.sql import exists as exists
|
||||
from sqlalchemy.sql import extract as extract
|
||||
from sqlalchemy.sql import false as false
|
||||
from sqlalchemy.sql import func as func
|
||||
from sqlalchemy.sql import funcfilter as funcfilter
|
||||
from sqlalchemy.sql import insert as insert
|
||||
from sqlalchemy.sql import intersect as intersect
|
||||
from sqlalchemy.sql import intersect_all as intersect_all
|
||||
from sqlalchemy.sql import join as join
|
||||
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
|
||||
from sqlalchemy.sql import (
|
||||
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
|
||||
)
|
||||
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
|
||||
from sqlalchemy.sql import (
|
||||
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
|
||||
)
|
||||
from sqlalchemy.sql import lambda_stmt as lambda_stmt
|
||||
from sqlalchemy.sql import lateral as lateral
|
||||
from sqlalchemy.sql import literal as literal
|
||||
from sqlalchemy.sql import literal_column as literal_column
|
||||
from sqlalchemy.sql import modifier as modifier
|
||||
from sqlalchemy.sql import not_ as not_
|
||||
from sqlalchemy.sql import null as null
|
||||
from sqlalchemy.sql import nulls_first as nulls_first
|
||||
from sqlalchemy.sql import nulls_last as nulls_last
|
||||
from sqlalchemy.sql import nullsfirst as nullsfirst
|
||||
from sqlalchemy.sql import nullslast as nullslast
|
||||
from sqlalchemy.sql import or_ as or_
|
||||
from sqlalchemy.sql import outerjoin as outerjoin
|
||||
from sqlalchemy.sql import outparam as outparam
|
||||
from sqlalchemy.sql import over as over
|
||||
from sqlalchemy.sql import subquery as subquery
|
||||
from sqlalchemy.sql import table as table
|
||||
from sqlalchemy.sql import tablesample as tablesample
|
||||
from sqlalchemy.sql import text as text
|
||||
from sqlalchemy.sql import true as true
|
||||
from sqlalchemy.sql import tuple_ as tuple_
|
||||
from sqlalchemy.sql import type_coerce as type_coerce
|
||||
from sqlalchemy.sql import union as union
|
||||
from sqlalchemy.sql import union_all as union_all
|
||||
from sqlalchemy.sql import update as update
|
||||
from sqlalchemy.sql import values as values
|
||||
from sqlalchemy.sql import within_group as within_group
|
||||
from sqlalchemy.types import ARRAY as ARRAY
|
||||
from sqlalchemy.types import BIGINT as BIGINT
|
||||
from sqlalchemy.types import BigInteger as BigInteger
|
||||
from sqlalchemy.types import BINARY as BINARY
|
||||
from sqlalchemy.types import BLOB as BLOB
|
||||
from sqlalchemy.types import BOOLEAN as BOOLEAN
|
||||
from sqlalchemy.types import Boolean as Boolean
|
||||
from sqlalchemy.types import CHAR as CHAR
|
||||
from sqlalchemy.types import CLOB as CLOB
|
||||
from sqlalchemy.types import DATE as DATE
|
||||
from sqlalchemy.types import Date as Date
|
||||
from sqlalchemy.types import DATETIME as DATETIME
|
||||
from sqlalchemy.types import DateTime as DateTime
|
||||
from sqlalchemy.types import DECIMAL as DECIMAL
|
||||
from sqlalchemy.types import Enum as Enum
|
||||
from sqlalchemy.types import FLOAT as FLOAT
|
||||
from sqlalchemy.types import Float as Float
|
||||
from sqlalchemy.types import INT as INT
|
||||
from sqlalchemy.types import INTEGER as INTEGER
|
||||
from sqlalchemy.types import Integer as Integer
|
||||
from sqlalchemy.types import Interval as Interval
|
||||
from sqlalchemy.types import JSON as JSON
|
||||
from sqlalchemy.types import LargeBinary as LargeBinary
|
||||
from sqlalchemy.types import NCHAR as NCHAR
|
||||
from sqlalchemy.types import NUMERIC as NUMERIC
|
||||
from sqlalchemy.types import Numeric as Numeric
|
||||
from sqlalchemy.types import NVARCHAR as NVARCHAR
|
||||
from sqlalchemy.types import PickleType as PickleType
|
||||
from sqlalchemy.types import REAL as REAL
|
||||
from sqlalchemy.types import SMALLINT as SMALLINT
|
||||
from sqlalchemy.types import SmallInteger as SmallInteger
|
||||
from sqlalchemy.types import String as String
|
||||
from sqlalchemy.types import TEXT as TEXT
|
||||
from sqlalchemy.types import Text as Text
|
||||
from sqlalchemy.types import TIME as TIME
|
||||
from sqlalchemy.types import Time as Time
|
||||
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
|
||||
from sqlalchemy.types import TypeDecorator as TypeDecorator
|
||||
from sqlalchemy.types import Unicode as Unicode
|
||||
from sqlalchemy.types import UnicodeText as UnicodeText
|
||||
from sqlalchemy.types import VARBINARY as VARBINARY
|
||||
from sqlalchemy.types import VARCHAR as VARCHAR
|
||||
|
||||
# Extensions and modifications of SQLAlchemy in SQLModel
|
||||
from .engine.create import create_engine as create_engine
|
||||
from .orm.session import Session as Session
|
||||
from .sql.expression import select as select
|
||||
from .sql.expression import col as col
|
||||
from .sql.sqltypes import AutoString as AutoString
|
||||
|
||||
# Export SQLModel specifics (equivalent to Pydantic)
|
||||
from .main import SQLModel as SQLModel
|
||||
from .main import Field as Field
|
||||
from .main import Relationship as Relationship
|
32
sqlmodel/default.py
Normal file
32
sqlmodel/default.py
Normal file
@ -0,0 +1,32 @@
|
||||
from typing import Any, TypeVar
|
||||
|
||||
|
||||
class _DefaultPlaceholder:
|
||||
"""
|
||||
You shouldn't use this class directly.
|
||||
|
||||
It's used internally to recognize when a default value has been overwritten, even
|
||||
if the overriden default value was truthy.
|
||||
"""
|
||||
|
||||
def __init__(self, value: Any):
|
||||
self.value = value
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.value)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return isinstance(o, _DefaultPlaceholder) and o.value == self.value
|
||||
|
||||
|
||||
_TDefaultType = TypeVar("_TDefaultType")
|
||||
|
||||
|
||||
def Default(value: _TDefaultType) -> _TDefaultType:
|
||||
"""
|
||||
You shouldn't use this function directly.
|
||||
|
||||
It's used internally to recognize when a default value has been overwritten, even
|
||||
if the overriden default value was truthy.
|
||||
"""
|
||||
return _DefaultPlaceholder(value) # type: ignore
|
0
sqlmodel/engine/__init__.py
Normal file
0
sqlmodel/engine/__init__.py
Normal file
139
sqlmodel/engine/create.py
Normal file
139
sqlmodel/engine/create.py
Normal file
@ -0,0 +1,139 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
from sqlalchemy import create_engine as _create_engine
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.future import Engine as _FutureEngine
|
||||
from sqlalchemy.pool import Pool
|
||||
from typing_extensions import Literal, TypedDict
|
||||
|
||||
from ..default import Default, _DefaultPlaceholder
|
||||
|
||||
# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here
|
||||
|
||||
_Debug = Literal["debug"]
|
||||
|
||||
_IsolationLevel = Literal[
|
||||
"SERIALIZABLE",
|
||||
"REPEATABLE READ",
|
||||
"READ COMMITTED",
|
||||
"READ UNCOMMITTED",
|
||||
"AUTOCOMMIT",
|
||||
]
|
||||
_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
|
||||
_ResetOnReturn = Literal["rollback", "commit"]
|
||||
|
||||
|
||||
class _SQLiteConnectArgs(TypedDict, total=False):
|
||||
timeout: float
|
||||
detect_types: Any
|
||||
isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
|
||||
check_same_thread: bool
|
||||
factory: Type[sqlite3.Connection]
|
||||
cached_statements: int
|
||||
uri: bool
|
||||
|
||||
|
||||
_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
|
||||
|
||||
|
||||
# Re-define create_engine to have by default future=True, and assume that's what is used
|
||||
# Also show the default values used for each parameter, but don't set them unless
|
||||
# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't
|
||||
# support pool connection arguments.
|
||||
def create_engine(
|
||||
url: Union[str, URL],
|
||||
*,
|
||||
connect_args: _ConnectArgs = Default({}), # type: ignore
|
||||
echo: Union[bool, _Debug] = Default(False),
|
||||
echo_pool: Union[bool, _Debug] = Default(False),
|
||||
enable_from_linting: bool = Default(True),
|
||||
encoding: str = Default("utf-8"),
|
||||
execution_options: Dict[Any, Any] = Default({}),
|
||||
future: bool = True,
|
||||
hide_parameters: bool = Default(False),
|
||||
implicit_returning: bool = Default(True),
|
||||
isolation_level: Optional[_IsolationLevel] = Default(None),
|
||||
json_deserializer: Callable[..., Any] = Default(json.loads),
|
||||
json_serializer: Callable[..., Any] = Default(json.dumps),
|
||||
label_length: Optional[int] = Default(None),
|
||||
logging_name: Optional[str] = Default(None),
|
||||
max_identifier_length: Optional[int] = Default(None),
|
||||
max_overflow: int = Default(10),
|
||||
module: Optional[Any] = Default(None),
|
||||
paramstyle: Optional[_ParamStyle] = Default(None),
|
||||
pool: Optional[Pool] = Default(None),
|
||||
poolclass: Optional[Type[Pool]] = Default(None),
|
||||
pool_logging_name: Optional[str] = Default(None),
|
||||
pool_pre_ping: bool = Default(False),
|
||||
pool_size: int = Default(5),
|
||||
pool_recycle: int = Default(-1),
|
||||
pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"),
|
||||
pool_timeout: float = Default(30),
|
||||
pool_use_lifo: bool = Default(False),
|
||||
plugins: Optional[List[str]] = Default(None),
|
||||
query_cache_size: Optional[int] = Default(None),
|
||||
**kwargs: Any,
|
||||
) -> _FutureEngine:
|
||||
current_kwargs: Dict[str, Any] = {
|
||||
"future": future,
|
||||
}
|
||||
if not isinstance(echo, _DefaultPlaceholder):
|
||||
current_kwargs["echo"] = echo
|
||||
if not isinstance(echo_pool, _DefaultPlaceholder):
|
||||
current_kwargs["echo_pool"] = echo_pool
|
||||
if not isinstance(enable_from_linting, _DefaultPlaceholder):
|
||||
current_kwargs["enable_from_linting"] = enable_from_linting
|
||||
if not isinstance(connect_args, _DefaultPlaceholder):
|
||||
current_kwargs["connect_args"] = connect_args
|
||||
if not isinstance(encoding, _DefaultPlaceholder):
|
||||
current_kwargs["encoding"] = encoding
|
||||
if not isinstance(execution_options, _DefaultPlaceholder):
|
||||
current_kwargs["execution_options"] = execution_options
|
||||
if not isinstance(hide_parameters, _DefaultPlaceholder):
|
||||
current_kwargs["hide_parameters"] = hide_parameters
|
||||
if not isinstance(implicit_returning, _DefaultPlaceholder):
|
||||
current_kwargs["implicit_returning"] = implicit_returning
|
||||
if not isinstance(isolation_level, _DefaultPlaceholder):
|
||||
current_kwargs["isolation_level"] = isolation_level
|
||||
if not isinstance(json_deserializer, _DefaultPlaceholder):
|
||||
current_kwargs["json_deserializer"] = json_deserializer
|
||||
if not isinstance(json_serializer, _DefaultPlaceholder):
|
||||
current_kwargs["json_serializer"] = json_serializer
|
||||
if not isinstance(label_length, _DefaultPlaceholder):
|
||||
current_kwargs["label_length"] = label_length
|
||||
if not isinstance(logging_name, _DefaultPlaceholder):
|
||||
current_kwargs["logging_name"] = logging_name
|
||||
if not isinstance(max_identifier_length, _DefaultPlaceholder):
|
||||
current_kwargs["max_identifier_length"] = max_identifier_length
|
||||
if not isinstance(max_overflow, _DefaultPlaceholder):
|
||||
current_kwargs["max_overflow"] = max_overflow
|
||||
if not isinstance(module, _DefaultPlaceholder):
|
||||
current_kwargs["module"] = module
|
||||
if not isinstance(paramstyle, _DefaultPlaceholder):
|
||||
current_kwargs["paramstyle"] = paramstyle
|
||||
if not isinstance(pool, _DefaultPlaceholder):
|
||||
current_kwargs["pool"] = pool
|
||||
if not isinstance(poolclass, _DefaultPlaceholder):
|
||||
current_kwargs["poolclass"] = poolclass
|
||||
if not isinstance(pool_logging_name, _DefaultPlaceholder):
|
||||
current_kwargs["pool_logging_name"] = pool_logging_name
|
||||
if not isinstance(pool_pre_ping, _DefaultPlaceholder):
|
||||
current_kwargs["pool_pre_ping"] = pool_pre_ping
|
||||
if not isinstance(pool_size, _DefaultPlaceholder):
|
||||
current_kwargs["pool_size"] = pool_size
|
||||
if not isinstance(pool_recycle, _DefaultPlaceholder):
|
||||
current_kwargs["pool_recycle"] = pool_recycle
|
||||
if not isinstance(pool_reset_on_return, _DefaultPlaceholder):
|
||||
current_kwargs["pool_reset_on_return"] = pool_reset_on_return
|
||||
if not isinstance(pool_timeout, _DefaultPlaceholder):
|
||||
current_kwargs["pool_timeout"] = pool_timeout
|
||||
if not isinstance(pool_use_lifo, _DefaultPlaceholder):
|
||||
current_kwargs["pool_use_lifo"] = pool_use_lifo
|
||||
if not isinstance(plugins, _DefaultPlaceholder):
|
||||
current_kwargs["plugins"] = plugins
|
||||
if not isinstance(query_cache_size, _DefaultPlaceholder):
|
||||
current_kwargs["query_cache_size"] = query_cache_size
|
||||
current_kwargs.update(kwargs)
|
||||
return _create_engine(url, **current_kwargs)
|
79
sqlmodel/engine/result.py
Normal file
79
sqlmodel/engine/result.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import Generic, Iterator, List, Optional, TypeVar
|
||||
|
||||
from sqlalchemy.engine.result import Result as _Result
|
||||
from sqlalchemy.engine.result import ScalarResult as _ScalarResult
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class ScalarResult(_ScalarResult, Generic[_T]):
|
||||
def all(self) -> List[_T]:
|
||||
return super().all()
|
||||
|
||||
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
|
||||
return super().partitions(size)
|
||||
|
||||
def fetchall(self) -> List[_T]:
|
||||
return super().fetchall()
|
||||
|
||||
def fetchmany(self, size: Optional[int] = None) -> List[_T]:
|
||||
return super().fetchmany(size)
|
||||
|
||||
def __iter__(self) -> Iterator[_T]:
|
||||
return super().__iter__()
|
||||
|
||||
def __next__(self) -> _T:
|
||||
return super().__next__()
|
||||
|
||||
def first(self) -> Optional[_T]:
|
||||
return super().first()
|
||||
|
||||
def one_or_none(self) -> Optional[_T]:
|
||||
return super().one_or_none()
|
||||
|
||||
def one(self) -> _T:
|
||||
return super().one()
|
||||
|
||||
|
||||
class Result(_Result, Generic[_T]):
|
||||
def scalars(self, index: int = 0) -> ScalarResult[_T]:
|
||||
return super().scalars(index) # type: ignore
|
||||
|
||||
def __iter__(self) -> Iterator[_T]: # type: ignore
|
||||
return super().__iter__() # type: ignore
|
||||
|
||||
def __next__(self) -> _T: # type: ignore
|
||||
return super().__next__() # type: ignore
|
||||
|
||||
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore
|
||||
return super().partitions(size) # type: ignore
|
||||
|
||||
def fetchall(self) -> List[_T]: # type: ignore
|
||||
return super().fetchall() # type: ignore
|
||||
|
||||
def fetchone(self) -> Optional[_T]: # type: ignore
|
||||
return super().fetchone() # type: ignore
|
||||
|
||||
def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore
|
||||
return super().fetchmany() # type: ignore
|
||||
|
||||
def all(self) -> List[_T]: # type: ignore
|
||||
return super().all() # type: ignore
|
||||
|
||||
def first(self) -> Optional[_T]: # type: ignore
|
||||
return super().first() # type: ignore
|
||||
|
||||
def one_or_none(self) -> Optional[_T]: # type: ignore
|
||||
return super().one_or_none() # type: ignore
|
||||
|
||||
def scalar_one(self) -> _T:
|
||||
return super().scalar_one() # type: ignore
|
||||
|
||||
def scalar_one_or_none(self) -> Optional[_T]:
|
||||
return super().scalar_one_or_none() # type: ignore
|
||||
|
||||
def one(self) -> _T: # type: ignore
|
||||
return super().one() # type: ignore
|
||||
|
||||
def scalar(self) -> Optional[_T]:
|
||||
return super().scalar() # type: ignore
|
0
sqlmodel/ext/__init__.py
Normal file
0
sqlmodel/ext/__init__.py
Normal file
0
sqlmodel/ext/asyncio/__init__.py
Normal file
0
sqlmodel/ext/asyncio/__init__.py
Normal file
62
sqlmodel/ext/asyncio/session.py
Normal file
62
sqlmodel/ext/asyncio/session.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
|
||||
|
||||
from sqlalchemy import util
|
||||
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
|
||||
from sqlalchemy.ext.asyncio import engine
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
|
||||
from sqlalchemy.util.concurrency import greenlet_spawn
|
||||
from sqlmodel.sql.base import Executable
|
||||
|
||||
from ...engine.result import ScalarResult
|
||||
from ...orm.session import Session
|
||||
from ...sql.expression import Select
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class AsyncSession(_AsyncSession):
|
||||
sync_session: Session
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
|
||||
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
|
||||
**kw,
|
||||
):
|
||||
# All the same code of the original AsyncSession
|
||||
kw["future"] = True
|
||||
if bind:
|
||||
self.bind = bind
|
||||
bind = engine._get_sync_engine_or_connection(bind) # type: ignore
|
||||
|
||||
if binds:
|
||||
self.binds = binds
|
||||
binds = {
|
||||
key: engine._get_sync_engine_or_connection(b) # type: ignore
|
||||
for key, b in binds.items()
|
||||
}
|
||||
|
||||
self.sync_session = self._proxied = self._assign_proxied( # type: ignore
|
||||
Session(bind=bind, binds=binds, **kw) # type: ignore
|
||||
)
|
||||
|
||||
async def exec(
|
||||
self,
|
||||
statement: Union[Select[_T], Executable[_T]],
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
**kw: Any,
|
||||
) -> ScalarResult[_T]:
|
||||
# TODO: the documentation says execution_options accepts a dict, but only
|
||||
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
|
||||
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
|
||||
|
||||
return await greenlet_spawn( # type: ignore
|
||||
self.sync_session.exec,
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options,
|
||||
bind_arguments=bind_arguments,
|
||||
**kw,
|
||||
)
|
631
sqlmodel/main.py
Normal file
631
sqlmodel/main.py
Normal file
@ -0,0 +1,631 @@
|
||||
import ipaddress
|
||||
import uuid
|
||||
import weakref
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.errors import ConfigError, DictError
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
from pydantic.fields import ModelField, Undefined, UndefinedType
|
||||
from pydantic.main import BaseConfig, ModelMetaclass, validate_model
|
||||
from pydantic.typing import NoArgAnyCallable, resolve_annotations
|
||||
from pydantic.utils import ROOT_KEY, Representation, ValueItems
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Interval,
|
||||
Numeric,
|
||||
inspect,
|
||||
)
|
||||
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
|
||||
from sqlalchemy.orm.attributes import set_attribute
|
||||
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
||||
from sqlalchemy.orm.instrumentation import is_instrumented
|
||||
from sqlalchemy.sql.schema import MetaData
|
||||
from sqlalchemy.sql.sqltypes import LargeBinary, Time
|
||||
|
||||
from .sql.sqltypes import GUID, AutoString
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def __dataclass_transform__(
|
||||
*,
|
||||
eq_default: bool = True,
|
||||
order_default: bool = False,
|
||||
kw_only_default: bool = False,
|
||||
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
|
||||
) -> Callable[[_T], _T]:
|
||||
return lambda a: a
|
||||
|
||||
|
||||
class FieldInfo(PydanticFieldInfo):
|
||||
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
|
||||
primary_key = kwargs.pop("primary_key", False)
|
||||
nullable = kwargs.pop("nullable", Undefined)
|
||||
foreign_key = kwargs.pop("foreign_key", Undefined)
|
||||
index = kwargs.pop("index", Undefined)
|
||||
sa_column = kwargs.pop("sa_column", Undefined)
|
||||
sa_column_args = kwargs.pop("sa_column_args", Undefined)
|
||||
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
|
||||
if sa_column is not Undefined:
|
||||
if sa_column_args is not Undefined:
|
||||
raise RuntimeError(
|
||||
"Passing sa_column_args is not supported when "
|
||||
"also passing a sa_column"
|
||||
)
|
||||
if sa_column_kwargs is not Undefined:
|
||||
raise RuntimeError(
|
||||
"Passing sa_column_kwargs is not supported when "
|
||||
"also passing a sa_column"
|
||||
)
|
||||
super().__init__(default=default, **kwargs)
|
||||
self.primary_key = primary_key
|
||||
self.nullable = nullable
|
||||
self.foreign_key = foreign_key
|
||||
self.index = index
|
||||
self.sa_column = sa_column
|
||||
self.sa_column_args = sa_column_args
|
||||
self.sa_column_kwargs = sa_column_kwargs
|
||||
|
||||
|
||||
class RelationshipInfo(Representation):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None,
|
||||
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
if sa_relationship is not None:
|
||||
if sa_relationship_args is not None:
|
||||
raise RuntimeError(
|
||||
"Passing sa_relationship_args is not supported when "
|
||||
"also passing a sa_relationship"
|
||||
)
|
||||
if sa_relationship_kwargs is not None:
|
||||
raise RuntimeError(
|
||||
"Passing sa_relationship_kwargs is not supported when "
|
||||
"also passing a sa_relationship"
|
||||
)
|
||||
self.back_populates = back_populates
|
||||
self.link_model = link_model
|
||||
self.sa_relationship = sa_relationship
|
||||
self.sa_relationship_args = sa_relationship_args
|
||||
self.sa_relationship_kwargs = sa_relationship_kwargs
|
||||
|
||||
|
||||
def Field(
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: str = None,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
exclude: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
include: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
const: bool = None,
|
||||
gt: float = None,
|
||||
ge: float = None,
|
||||
lt: float = None,
|
||||
le: float = None,
|
||||
multiple_of: float = None,
|
||||
min_items: int = None,
|
||||
max_items: int = None,
|
||||
min_length: int = None,
|
||||
max_length: int = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: str = None,
|
||||
primary_key: bool = False,
|
||||
foreign_key: Optional[Any] = None,
|
||||
nullable: Union[bool, UndefinedType] = Undefined,
|
||||
index: Union[bool, UndefinedType] = Undefined,
|
||||
sa_column: Union[Column, UndefinedType] = Undefined,
|
||||
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
|
||||
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
|
||||
schema_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
current_schema_extra = schema_extra or {}
|
||||
field_info = FieldInfo(
|
||||
default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
primary_key=primary_key,
|
||||
foreign_key=foreign_key,
|
||||
nullable=nullable,
|
||||
index=index,
|
||||
sa_column=sa_column,
|
||||
sa_column_args=sa_column_args,
|
||||
sa_column_kwargs=sa_column_kwargs,
|
||||
**current_schema_extra,
|
||||
)
|
||||
field_info._validate()
|
||||
return field_info
|
||||
|
||||
|
||||
def Relationship(
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None,
|
||||
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> Any:
|
||||
relationship_info = RelationshipInfo(
|
||||
back_populates=back_populates,
|
||||
link_model=link_model,
|
||||
sa_relationship=sa_relationship,
|
||||
sa_relationship_args=sa_relationship_args,
|
||||
sa_relationship_kwargs=sa_relationship_kwargs,
|
||||
)
|
||||
return relationship_info
|
||||
|
||||
|
||||
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
|
||||
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
|
||||
__config__: Type[BaseConfig]
|
||||
__fields__: Dict[str, ModelField]
|
||||
|
||||
# Replicate SQLAlchemy
|
||||
def __setattr__(cls, name: str, value: Any) -> None:
|
||||
if getattr(cls.__config__, "table", False): # type: ignore
|
||||
DeclarativeMeta.__setattr__(cls, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __delattr__(cls, name: str) -> None:
|
||||
if getattr(cls.__config__, "table", False): # type: ignore
|
||||
DeclarativeMeta.__delattr__(cls, name)
|
||||
else:
|
||||
super().__delattr__(name)
|
||||
|
||||
# From Pydantic
|
||||
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
|
||||
relationships: Dict[str, RelationshipInfo] = {}
|
||||
dict_for_pydantic = {}
|
||||
original_annotations = resolve_annotations(
|
||||
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
|
||||
)
|
||||
pydantic_annotations = {}
|
||||
relationship_annotations = {}
|
||||
for k, v in class_dict.items():
|
||||
if isinstance(v, RelationshipInfo):
|
||||
relationships[k] = v
|
||||
else:
|
||||
dict_for_pydantic[k] = v
|
||||
for k, v in original_annotations.items():
|
||||
if k in relationships:
|
||||
relationship_annotations[k] = v
|
||||
else:
|
||||
pydantic_annotations[k] = v
|
||||
dict_used = {
|
||||
**dict_for_pydantic,
|
||||
"__weakref__": None,
|
||||
"__sqlmodel_relationships__": relationships,
|
||||
"__annotations__": pydantic_annotations,
|
||||
}
|
||||
# Duplicate logic from Pydantic to filter config kwargs because if they are
|
||||
# passed directly including the registry Pydantic will pass them over to the
|
||||
# superclass causing an error
|
||||
allowed_config_kwargs: Set[str] = {
|
||||
key
|
||||
for key in dir(BaseConfig)
|
||||
if not (
|
||||
key.startswith("__") and key.endswith("__")
|
||||
) # skip dunder methods and attributes
|
||||
}
|
||||
pydantic_kwargs = kwargs.copy()
|
||||
config_kwargs = {
|
||||
key: pydantic_kwargs.pop(key)
|
||||
for key in pydantic_kwargs.keys() & allowed_config_kwargs
|
||||
}
|
||||
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
|
||||
new_cls.__annotations__ = {
|
||||
**relationship_annotations,
|
||||
**pydantic_annotations,
|
||||
**new_cls.__annotations__,
|
||||
}
|
||||
|
||||
def get_config(name: str) -> Any:
|
||||
config_class_value = getattr(new_cls.__config__, name, Undefined)
|
||||
if config_class_value is not Undefined:
|
||||
return config_class_value
|
||||
kwarg_value = kwargs.get(name, Undefined)
|
||||
if kwarg_value is not Undefined:
|
||||
return kwarg_value
|
||||
return Undefined
|
||||
|
||||
config_table = get_config("table")
|
||||
if config_table is True:
|
||||
# If it was passed by kwargs, ensure it's also set in config
|
||||
new_cls.__config__.table = config_table
|
||||
for k, v in new_cls.__fields__.items():
|
||||
col = get_column_from_field(v)
|
||||
setattr(new_cls, k, col)
|
||||
# Set a config flag to tell FastAPI that this should be read with a field
|
||||
# in orm_mode instead of preemptively converting it to a dict.
|
||||
# This could be done by reading new_cls.__config__.table in FastAPI, but
|
||||
# that's very specific about SQLModel, so let's have another config that
|
||||
# other future tools based on Pydantic can use.
|
||||
new_cls.__config__.read_with_orm_mode = True
|
||||
|
||||
config_registry = get_config("registry")
|
||||
if config_registry is not Undefined:
|
||||
config_registry = cast(registry, config_registry)
|
||||
# If it was passed by kwargs, ensure it's also set in config
|
||||
new_cls.__config__.registry = config_table
|
||||
setattr(new_cls, "_sa_registry", config_registry)
|
||||
setattr(new_cls, "metadata", config_registry.metadata)
|
||||
setattr(new_cls, "__abstract__", True)
|
||||
return new_cls
|
||||
|
||||
# Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
|
||||
def __init__(
|
||||
cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
|
||||
) -> None:
|
||||
# Only one of the base classes (or the current one) should be a table model
|
||||
# this allows FastAPI cloning a SQLModel for the response_model without
|
||||
# trying to create a new SQLAlchemy, for a new table, with the same name, that
|
||||
# triggers an error
|
||||
base_is_table = False
|
||||
for base in bases:
|
||||
config = getattr(base, "__config__")
|
||||
if config and getattr(config, "table", False):
|
||||
base_is_table = True
|
||||
break
|
||||
if getattr(cls.__config__, "table", False) and not base_is_table:
|
||||
dict_used = dict_.copy()
|
||||
for field_name, field_value in cls.__fields__.items():
|
||||
dict_used[field_name] = get_column_from_field(field_value)
|
||||
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
||||
if rel_info.sa_relationship:
|
||||
# There's a SQLAlchemy relationship declared, that takes precedence
|
||||
# over anything else, use that and continue with the next attribute
|
||||
dict_used[rel_name] = rel_info.sa_relationship
|
||||
continue
|
||||
ann = cls.__annotations__[rel_name]
|
||||
temp_field = ModelField.infer(
|
||||
name=rel_name,
|
||||
value=rel_info,
|
||||
annotation=ann,
|
||||
class_validators=None,
|
||||
config=BaseConfig,
|
||||
)
|
||||
relationship_to = temp_field.type_
|
||||
if isinstance(temp_field.type_, ForwardRef):
|
||||
relationship_to = temp_field.type_.__forward_arg__
|
||||
rel_kwargs: Dict[str, Any] = {}
|
||||
if rel_info.back_populates:
|
||||
rel_kwargs["back_populates"] = rel_info.back_populates
|
||||
if rel_info.link_model:
|
||||
ins = inspect(rel_info.link_model)
|
||||
local_table = getattr(ins, "local_table")
|
||||
if local_table is None:
|
||||
raise RuntimeError(
|
||||
"Couldn't find the secondary table for "
|
||||
f"model {rel_info.link_model}"
|
||||
)
|
||||
rel_kwargs["secondary"] = local_table
|
||||
rel_args: List[Any] = []
|
||||
if rel_info.sa_relationship_args:
|
||||
rel_args.extend(rel_info.sa_relationship_args)
|
||||
if rel_info.sa_relationship_kwargs:
|
||||
rel_kwargs.update(rel_info.sa_relationship_kwargs)
|
||||
rel_value: RelationshipProperty = relationship(
|
||||
relationship_to, *rel_args, **rel_kwargs
|
||||
)
|
||||
dict_used[rel_name] = rel_value
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
|
||||
else:
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
|
||||
def get_sqlachemy_type(field: ModelField) -> Any:
|
||||
if issubclass(field.type_, str):
|
||||
if field.field_info.max_length:
|
||||
return AutoString(length=field.field_info.max_length)
|
||||
return AutoString
|
||||
if issubclass(field.type_, float):
|
||||
return Float
|
||||
if issubclass(field.type_, bool):
|
||||
return Boolean
|
||||
if issubclass(field.type_, int):
|
||||
return Integer
|
||||
if issubclass(field.type_, datetime):
|
||||
return DateTime
|
||||
if issubclass(field.type_, date):
|
||||
return Date
|
||||
if issubclass(field.type_, timedelta):
|
||||
return Interval
|
||||
if issubclass(field.type_, time):
|
||||
return Time
|
||||
if issubclass(field.type_, Enum):
|
||||
return Enum
|
||||
if issubclass(field.type_, bytes):
|
||||
return LargeBinary
|
||||
if issubclass(field.type_, Decimal):
|
||||
return Numeric
|
||||
if issubclass(field.type_, ipaddress.IPv4Address):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv4Network):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv6Address):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv6Network):
|
||||
return AutoString
|
||||
if issubclass(field.type_, Path):
|
||||
return AutoString
|
||||
if issubclass(field.type_, uuid.UUID):
|
||||
return GUID
|
||||
|
||||
|
||||
def get_column_from_field(field: ModelField) -> Column:
|
||||
sa_column = getattr(field.field_info, "sa_column", Undefined)
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
sa_type = get_sqlachemy_type(field)
|
||||
primary_key = getattr(field.field_info, "primary_key", False)
|
||||
nullable = not field.required
|
||||
index = getattr(field.field_info, "index", Undefined)
|
||||
if index is Undefined:
|
||||
index = True
|
||||
if hasattr(field.field_info, "nullable"):
|
||||
field_nullable = getattr(field.field_info, "nullable")
|
||||
if field_nullable != Undefined:
|
||||
nullable = field_nullable
|
||||
args = []
|
||||
foreign_key = getattr(field.field_info, "foreign_key", None)
|
||||
if foreign_key:
|
||||
args.append(ForeignKey(foreign_key))
|
||||
kwargs = {
|
||||
"primary_key": primary_key,
|
||||
"nullable": nullable,
|
||||
"index": index,
|
||||
}
|
||||
sa_default = Undefined
|
||||
if field.field_info.default_factory:
|
||||
sa_default = field.field_info.default_factory
|
||||
elif field.field_info.default is not Undefined:
|
||||
sa_default = field.field_info.default
|
||||
if sa_default is not Undefined:
|
||||
kwargs["default"] = sa_default
|
||||
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
|
||||
if sa_column_args is not Undefined:
|
||||
args.extend(list(cast(Sequence, sa_column_args)))
|
||||
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
|
||||
if sa_column_kwargs is not Undefined:
|
||||
kwargs.update(cast(dict, sa_column_kwargs))
|
||||
return Column(sa_type, *args, **kwargs)
|
||||
|
||||
|
||||
class_registry = weakref.WeakValueDictionary() # type: ignore
|
||||
|
||||
default_registry = registry()
|
||||
|
||||
|
||||
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
|
||||
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
|
||||
__slots__ = ("__weakref__",)
|
||||
__tablename__: ClassVar[Union[str, Callable[..., str]]]
|
||||
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
|
||||
__name__: ClassVar[str]
|
||||
metadata: ClassVar[MetaData]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> Any:
|
||||
new_object = super().__new__(cls)
|
||||
# SQLAlchemy doesn't call __init__ on the base class
|
||||
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
|
||||
# Set __fields_set__ here, that would have been set when calling __init__
|
||||
# in the Pydantic model so that when SQLAlchemy sets attributes that are
|
||||
# added (e.g. when querying from DB) to the __fields_set__, this already exists
|
||||
object.__setattr__(new_object, "__fields_set__", set())
|
||||
return new_object
|
||||
|
||||
def __init__(__pydantic_self__, **data: Any) -> None:
|
||||
# Uses something other than `self` the first arg to allow "self" as a
|
||||
# settable attribute
|
||||
if TYPE_CHECKING:
|
||||
__pydantic_self__.__dict__: Dict[str, Any] = {}
|
||||
__pydantic_self__.__fields_set__: Set[str] = set()
|
||||
values, fields_set, validation_error = validate_model(
|
||||
__pydantic_self__.__class__, data
|
||||
)
|
||||
# Only raise errors if not a SQLModel model
|
||||
if (
|
||||
not getattr(__pydantic_self__.__config__, "table", False)
|
||||
and validation_error
|
||||
):
|
||||
raise validation_error
|
||||
# Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
|
||||
# can handle them
|
||||
# object.__setattr__(__pydantic_self__, '__dict__', values)
|
||||
object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
|
||||
for key, value in values.items():
|
||||
setattr(__pydantic_self__, key, value)
|
||||
non_pydantic_keys = data.keys() - values.keys()
|
||||
for key in non_pydantic_keys:
|
||||
if key in __pydantic_self__.__sqlmodel_relationships__:
|
||||
setattr(__pydantic_self__, key, data[key])
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in {"_sa_instance_state"}:
|
||||
self.__dict__[name] = value
|
||||
return
|
||||
else:
|
||||
# Set in SQLAlchemy, before Pydantic to trigger events and updates
|
||||
if getattr(self.__config__, "table", False):
|
||||
if is_instrumented(self, name):
|
||||
set_attribute(self, name, value)
|
||||
# Set in Pydantic model to trigger possible validation changes, only for
|
||||
# non relationship values
|
||||
if name not in self.__sqlmodel_relationships__:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
|
||||
# Duplicated from Pydantic
|
||||
if not cls.__config__.orm_mode:
|
||||
raise ConfigError(
|
||||
"You must have the config attribute orm_mode=True to use from_orm"
|
||||
)
|
||||
obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
|
||||
# SQLModel, support update dict
|
||||
if update is not None:
|
||||
obj = {**obj, **update}
|
||||
# End SQLModel support dict
|
||||
if not getattr(cls.__config__, "table", False):
|
||||
# If not table, normal Pydantic code
|
||||
m = cls.__new__(cls)
|
||||
else:
|
||||
# If table, create the new instance normally to make SQLAlchemy create
|
||||
# the _sa_instance_state attribute
|
||||
m = cls()
|
||||
values, fields_set, validation_error = validate_model(cls, obj)
|
||||
if validation_error:
|
||||
raise validation_error
|
||||
# Updated to trigger SQLAlchemy internal handling
|
||||
if not getattr(cls.__config__, "table", False):
|
||||
object.__setattr__(m, "__dict__", values)
|
||||
else:
|
||||
for key, value in values.items():
|
||||
setattr(m, key, value)
|
||||
# Continue with standard Pydantic logic
|
||||
object.__setattr__(m, "__fields_set__", fields_set)
|
||||
m._init_private_attributes()
|
||||
return m
|
||||
|
||||
@classmethod
|
||||
def parse_obj(
|
||||
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
|
||||
) -> "SQLModel":
|
||||
obj = cls._enforce_dict_if_root(obj)
|
||||
# SQLModel, support update dict
|
||||
if update is not None:
|
||||
obj = {**obj, **update}
|
||||
# End SQLModel support dict
|
||||
return super().parse_obj(obj)
|
||||
|
||||
def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
|
||||
# Don't show SQLAlchemy private attributes
|
||||
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
|
||||
|
||||
# From Pydantic, override to enforce validation with dict
|
||||
@classmethod
|
||||
def validate(cls: Type["SQLModel"], value: Any) -> "SQLModel":
|
||||
if isinstance(value, cls):
|
||||
return value.copy() if cls.__config__.copy_on_model_validation else value
|
||||
|
||||
value = cls._enforce_dict_if_root(value)
|
||||
if isinstance(value, dict):
|
||||
values, fields_set, validation_error = validate_model(cls, value)
|
||||
if validation_error:
|
||||
raise validation_error
|
||||
model = cls(**values)
|
||||
# Reset fields set, this would have been done in Pydantic in __init__
|
||||
object.__setattr__(model, "__fields_set__", fields_set)
|
||||
return model
|
||||
elif cls.__config__.orm_mode:
|
||||
return cls.from_orm(value)
|
||||
elif cls.__custom_root_type__:
|
||||
return cls.parse_obj(value)
|
||||
else:
|
||||
try:
|
||||
value_as_dict = dict(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise DictError() from e
|
||||
return cls(**value_as_dict)
|
||||
|
||||
# From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
|
||||
def _calculate_keys(
|
||||
self,
|
||||
include: Optional[Mapping[Union[int, str], Any]],
|
||||
exclude: Optional[Mapping[Union[int, str], Any]],
|
||||
exclude_unset: bool,
|
||||
update: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[AbstractSet[str]]:
|
||||
if include is None and exclude is None and exclude_unset is False:
|
||||
# Original in Pydantic:
|
||||
# return None
|
||||
# Updated to not return SQLAlchemy attributes
|
||||
# Do not include relationships as that would easily lead to infinite
|
||||
# recursion, or traversing the whole database
|
||||
return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
|
||||
|
||||
keys: AbstractSet[str]
|
||||
if exclude_unset:
|
||||
keys = self.__fields_set__.copy()
|
||||
else:
|
||||
# Original in Pydantic:
|
||||
# keys = self.__dict__.keys()
|
||||
# Updated to not return SQLAlchemy attributes
|
||||
# Do not include relationships as that would easily lead to infinite
|
||||
# recursion, or traversing the whole database
|
||||
keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
|
||||
|
||||
if include is not None:
|
||||
keys &= include.keys()
|
||||
|
||||
if update:
|
||||
keys -= update.keys()
|
||||
|
||||
if exclude:
|
||||
keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)}
|
||||
|
||||
return keys
|
||||
|
||||
@declared_attr # type: ignore
|
||||
def __tablename__(cls) -> str:
|
||||
return cls.__name__.lower()
|
0
sqlmodel/orm/__init__.py
Normal file
0
sqlmodel/orm/__init__.py
Normal file
135
sqlmodel/orm/session.py
Normal file
135
sqlmodel/orm/session.py
Normal file
@ -0,0 +1,135 @@
|
||||
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
|
||||
|
||||
from sqlalchemy import util
|
||||
from sqlalchemy.orm import Query as _Query
|
||||
from sqlalchemy.orm import Session as _Session
|
||||
from sqlalchemy.sql.base import Executable as _Executable
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..engine.result import Result, ScalarResult
|
||||
from ..sql.base import Executable
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class Session(_Session):
|
||||
@overload
|
||||
def exec(
|
||||
self,
|
||||
statement: Select[_T],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Union[Result[_T]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def exec(
|
||||
self,
|
||||
statement: SelectOfScalar[_T],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Union[ScalarResult[_T]]:
|
||||
...
|
||||
|
||||
def exec(
|
||||
self,
|
||||
statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Union[Result[_T], ScalarResult[_T]]:
|
||||
results = super().execute(
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options, # type: ignore
|
||||
bind_arguments=bind_arguments,
|
||||
_parent_execute_state=_parent_execute_state,
|
||||
_add_event=_add_event,
|
||||
**kw,
|
||||
)
|
||||
if isinstance(statement, SelectOfScalar):
|
||||
return results.scalars() # type: ignore
|
||||
return results # type: ignore
|
||||
|
||||
def execute(
|
||||
self,
|
||||
statement: _Executable,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Result[Any]:
|
||||
"""
|
||||
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
|
||||
|
||||
This is the original SQLAlchemy `session.execute()` method that returns objects
|
||||
of type `Row`, and that you have to call `scalars()` to get the model objects.
|
||||
|
||||
For example:
|
||||
|
||||
```Python
|
||||
heroes = session.execute(select(Hero)).scalars().all()
|
||||
```
|
||||
|
||||
instead you could use `exec()`:
|
||||
|
||||
```Python
|
||||
heroes = session.exec(select(Hero)).all()
|
||||
```
|
||||
"""
|
||||
return super().execute( # type: ignore
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options, # type: ignore
|
||||
bind_arguments=bind_arguments,
|
||||
_parent_execute_state=_parent_execute_state,
|
||||
_add_event=_add_event,
|
||||
**kw,
|
||||
)
|
||||
|
||||
def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
|
||||
"""
|
||||
🚨 You probably want to use `session.exec()` instead of `session.query()`.
|
||||
|
||||
`session.exec()` is SQLModel's own short version with increased type
|
||||
annotations.
|
||||
|
||||
Or otherwise you might want to use `session.execute()` instead of
|
||||
`session.query()`.
|
||||
"""
|
||||
return super().query(*entities, **kwargs)
|
||||
|
||||
def get(
|
||||
self,
|
||||
entity: _T,
|
||||
ident: Any,
|
||||
options: Optional[Sequence[Any]] = None,
|
||||
populate_existing: bool = False,
|
||||
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
|
||||
identity_token: Optional[Any] = None,
|
||||
) -> _T:
|
||||
return super().get(
|
||||
entity,
|
||||
ident,
|
||||
options=options,
|
||||
populate_existing=populate_existing,
|
||||
with_for_update=with_for_update,
|
||||
identity_token=identity_token,
|
||||
)
|
1
sqlmodel/pool/__init__.py
Normal file
1
sqlmodel/pool/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from sqlalchemy.pool import StaticPool as StaticPool # noqa: F401
|
0
sqlmodel/sql/__init__.py
Normal file
0
sqlmodel/sql/__init__.py
Normal file
11
sqlmodel/sql/base.py
Normal file
11
sqlmodel/sql/base.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from sqlalchemy.sql.base import Executable as _Executable
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class Executable(_Executable, Generic[_T]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
|
||||
super(_Executable, self).__init__(*args, **kwargs)
|
459
sqlmodel/sql/expression.py
Normal file
459
sqlmodel/sql/expression.py
Normal file
@ -0,0 +1,459 @@
|
||||
# WARNING: do not modify this code, it is generated by expression.py.jinja2
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
from sqlalchemy.sql.elements import ColumnClause
|
||||
from sqlalchemy.sql.expression import Select as _Select
|
||||
|
||||
_TSelect = TypeVar("_TSelect")
|
||||
|
||||
# Workaround Generics incompatibility in Python 3.6
|
||||
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
|
||||
if sys.version_info.minor >= 7:
|
||||
|
||||
class Select(_Select, Generic[_TSelect]):
|
||||
pass
|
||||
|
||||
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
|
||||
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
|
||||
# entity, so the result will be converted to a scalar by default. This way writing
|
||||
# for loops on the results will feel natural.
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
from typing import GenericMeta # type: ignore
|
||||
|
||||
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
|
||||
pass
|
||||
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
# Cast them for editors to work correctly, from several tricks tried, this works
|
||||
# for both VS Code and PyCharm
|
||||
Select = cast("Select", _Py36Select) # type: ignore
|
||||
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
|
||||
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..main import SQLModel
|
||||
|
||||
# Generated TypeVars start
|
||||
|
||||
|
||||
_TScalar_0 = TypeVar(
|
||||
"_TScalar_0",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
int,
|
||||
bool,
|
||||
bytes,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
_TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
|
||||
|
||||
|
||||
_TScalar_1 = TypeVar(
|
||||
"_TScalar_1",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
int,
|
||||
bool,
|
||||
bytes,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
_TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
|
||||
|
||||
|
||||
_TScalar_2 = TypeVar(
|
||||
"_TScalar_2",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
int,
|
||||
bool,
|
||||
bytes,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
_TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
|
||||
|
||||
|
||||
_TScalar_3 = TypeVar(
|
||||
"_TScalar_3",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
int,
|
||||
bool,
|
||||
bytes,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
_TModel_3 = TypeVar("_TModel_3", bound="SQLModel")
|
||||
|
||||
|
||||
# Generated TypeVars end
|
||||
|
||||
|
||||
@overload
|
||||
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
|
||||
...
|
||||
|
||||
|
||||
# Generated overloads start
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: _TScalar_1,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TScalar_1]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: Type[_TModel_1],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TModel_1]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: _TScalar_1,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TScalar_1]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: Type[_TModel_1],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TModel_1]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: _TScalar_2,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: Type[_TModel_2],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: _TScalar_2,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: Type[_TModel_2],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: _TScalar_2,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: Type[_TModel_2],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: _TScalar_2,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: Type[_TModel_2],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: _TScalar_0,
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: _TScalar_1,
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: _TScalar_2,
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: _TScalar_3,
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
entity_0: Type[_TModel_0],
|
||||
entity_1: Type[_TModel_1],
|
||||
entity_2: Type[_TModel_2],
|
||||
entity_3: Type[_TModel_3],
|
||||
**kw: Any,
|
||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]:
|
||||
...
|
||||
|
||||
|
||||
# Generated overloads end
|
||||
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw) # type: ignore
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
119
sqlmodel/sql/expression.py.jinja2
Normal file
119
sqlmodel/sql/expression.py.jinja2
Normal file
@ -0,0 +1,119 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
from sqlalchemy.sql.elements import ColumnClause
|
||||
from sqlalchemy.sql.expression import Select as _Select
|
||||
|
||||
_TSelect = TypeVar("_TSelect")
|
||||
|
||||
# Workaround Generics incompatibility in Python 3.6
|
||||
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
|
||||
if sys.version_info.minor >= 7:
|
||||
|
||||
class Select(_Select, Generic[_TSelect]):
|
||||
pass
|
||||
|
||||
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
|
||||
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
|
||||
# entity, so the result will be converted to a scalar by default. This way writing
|
||||
# for loops on the results will feel natural.
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
from typing import GenericMeta # type: ignore
|
||||
|
||||
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
|
||||
pass
|
||||
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
# Cast them for editors to work correctly, from several tricks tried, this works
|
||||
# for both VS Code and PyCharm
|
||||
Select = cast("Select", _Py36Select) # type: ignore
|
||||
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
|
||||
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..main import SQLModel
|
||||
|
||||
# Generated TypeVars start
|
||||
|
||||
{% for i in range(number_of_types) %}
|
||||
_TScalar_{{ i }} = TypeVar(
|
||||
"_TScalar_{{ i }}",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
int,
|
||||
bool,
|
||||
bytes,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
|
||||
|
||||
{% endfor %}
|
||||
|
||||
# Generated TypeVars end
|
||||
|
||||
@overload
|
||||
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
|
||||
...
|
||||
|
||||
|
||||
# Generated overloads start
|
||||
|
||||
{% for signature in signatures %}
|
||||
|
||||
@overload
|
||||
def select( # type: ignore
|
||||
{% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
|
||||
) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
|
||||
...
|
||||
|
||||
{% endfor %}
|
||||
|
||||
# Generated overloads end
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw)
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
60
sqlmodel/sql/sqltypes.py
Normal file
60
sqlmodel/sql/sqltypes.py
Normal file
@ -0,0 +1,60 @@
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.types import CHAR, TypeDecorator
|
||||
|
||||
|
||||
class AutoString(types.TypeDecorator):
|
||||
|
||||
impl = types.String
|
||||
cache_ok = True
|
||||
mysql_default_length = 255
|
||||
|
||||
def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
|
||||
impl = cast(types.String, self.impl)
|
||||
if impl.length is None and dialect.name == "mysql":
|
||||
return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore
|
||||
return super().load_dialect_impl(dialect)
|
||||
|
||||
|
||||
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
|
||||
# with small modifications
|
||||
class GUID(TypeDecorator):
|
||||
"""Platform-independent GUID type.
|
||||
|
||||
Uses PostgreSQL's UUID type, otherwise uses
|
||||
CHAR(32), storing as stringified hex values.
|
||||
|
||||
"""
|
||||
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(32))
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == "postgresql":
|
||||
return str(value)
|
||||
else:
|
||||
if not isinstance(value, uuid.UUID):
|
||||
return f"{uuid.UUID(value).int:x}"
|
||||
else:
|
||||
# hexstring
|
||||
return f"{value.int:x}"
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
if not isinstance(value, uuid.UUID):
|
||||
value = uuid.UUID(value)
|
||||
return value
|
Loading…
Reference in New Issue
Block a user