diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 80267b2..e6b176a 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -1,9 +1,11 @@ -from typing import Any, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Dict, Mapping, Optional, Sequence, Type, TypeVar, Union from sqlalchemy import util from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession from sqlalchemy.ext.asyncio import engine from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.orm import Mapper +from sqlalchemy.sql.expression import TableClause from sqlalchemy.util.concurrency import greenlet_spawn from sqlmodel.sql.base import Executable @@ -14,13 +16,18 @@ from ...sql.expression import Select _T = TypeVar("_T") +BindsType = Dict[ + Union[Type[Any], Mapper[Any], TableClause, str], Union[AsyncEngine, AsyncConnection] +] + + class AsyncSession(_AsyncSession): sync_session: Session def __init__( self, bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, - binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, + binds: Optional[BindsType] = None, **kw: Any, ): # All the same code of the original AsyncSession diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 141005f..56ba414 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -61,6 +61,9 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): + + nullable: Union[bool, PydanticUndefinedType] + def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) nullable = kwargs.pop("nullable", PydanticUndefined) @@ -587,7 +590,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry def _is_field_noneable(field: FieldInfo) -> bool: - if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: + if not isinstance(field.nullable, PydanticUndefinedType): return field.nullable if not field.is_required(): default = getattr(field, "original_default", field.default) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 9a07956..dab26c3 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -12,9 +12,7 @@ from typing import ( from sqlalchemy import util from sqlalchemy.orm import Mapper as _Mapper -from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session -from sqlalchemy.sql.base import Executable as _Executable from sqlalchemy.sql.selectable import ForUpdateArg as _ForUpdateArg from sqlmodel.sql.expression import Select, SelectOfScalar @@ -81,56 +79,6 @@ class Session(_Session): return results.scalars() # type: ignore return results # type: ignore - def execute( - self, - statement: _Executable, - params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, - _parent_execute_state: Optional[Any] = None, - _add_event: Optional[Any] = None, - **kw: Any, - ) -> Result[Any]: - """ - 🚨 You probably want to use `session.exec()` instead of `session.execute()`. - - This is the original SQLAlchemy `session.execute()` method that returns objects - of type `Row`, and that you have to call `scalars()` to get the model objects. - - For example: - - ```Python - heroes = session.execute(select(Hero)).scalars().all() - ``` - - instead you could use `exec()`: - - ```Python - heroes = session.exec(select(Hero)).all() - ``` - """ - return super().execute( # type: ignore - statement, - params=params, - execution_options=execution_options, - bind_arguments=bind_arguments, - _parent_execute_state=_parent_execute_state, - _add_event=_add_event, - **kw, - ) - - def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": - """ - 🚨 You probably want to use `session.exec()` instead of `session.query()`. - - `session.exec()` is SQLModel's own short version with increased type - annotations. - - Or otherwise you might want to use `session.execute()` instead of - `session.query()`. - """ - return super().query(*entities, **kwargs) # type: ignore - def get( self, entity: Union[Type[_TSelectParam], "_Mapper[_TSelectParam]"], @@ -152,3 +100,34 @@ class Session(_Session): execution_options=execution_options, bind_arguments=bind_arguments, ) + + +Session.query.__doc__ = """ +🚨 You probably want to use `session.exec()` instead of `session.query()`. + +`session.exec()` is SQLModel's own short version with increased type +annotations. + +Or otherwise you might want to use `session.execute()` instead of +`session.query()`. +""" + + +Session.execute.__doc__ = """ +🚨 You probably want to use `session.exec()` instead of `session.execute()`. + +This is the original SQLAlchemy `session.execute()` method that returns objects +of type `Row`, and that you have to call `scalars()` to get the model objects. + +For example: + +```Python +heroes = session.execute(select(Hero)).scalars().all() +``` + +instead you could use `exec()`: + +```Python +heroes = session.exec(select(Hero)).all() +``` +""" diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 8cb2309..a0ac1bd 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -35,14 +35,6 @@ class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True -# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different -# purpose. This is the same as a normal SQLAlchemy Select class where there's only one -# entity, so the result will be converted to a scalar by default. This way writing -# for loops on the results will feel natural. -class SelectOfScalar(_Select, Generic[_TSelect]): - inherit_cache = True - - if TYPE_CHECKING: # pragma: no cover from ..main import SQLModel diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 55f4a1a..4284543 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -30,12 +30,6 @@ class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True -# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different -# purpose. This is the same as a normal SQLAlchemy Select class where there's only one -# entity, so the result will be converted to a scalar by default. This way writing -# for loops on the results will feel natural. -class SelectOfScalar(_Select, Generic[_TSelect]): - inherit_cache = True if TYPE_CHECKING: # pragma: no cover from ..main import SQLModel