GramCore/sqlmodel/session.py
2023-08-05 10:41:12 +08:00

119 lines
4.0 KiB
Python

from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload
from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.engine.result import Result, ScalarResult
from sqlmodel.orm.session import Session
from sqlmodel.sql.base import Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal
# Public API of this module: only the async session class is exported.
__all__ = ("AsyncSession",)

# Generic parameter threaded through the ``Select``/``SelectOfScalar`` statement
# types and the ``Result``/``ScalarResult``/entity return types below.
_TSelectParam = TypeVar("_TSelectParam")
class AsyncSession(_AsyncSession):  # pylint: disable=W0223
    """SQLAlchemy ``AsyncSession`` that wraps SQLModel's synchronous ``Session``.

    SQLModel's ``Session`` is used as the wrapped sync session class, both as
    the class-level default and as the ``__init__`` default. On top of the
    inherited API this adds an ``exec`` coroutine: ``SelectOfScalar``
    statements yield a ``ScalarResult``, and any other statement yields a
    ``Result``.
    """

    sync_session_class = Session
    sync_session: Session

    def __init__(
        self,
        bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
        binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
        sync_session_class: Type[Session] = Session,
        **kw: Any,
    ):
        """Build the session by delegating to SQLAlchemy's constructor.

        Args:
            bind: Async engine or connection the session is bound to.
            binds: Mapping from bind keys to async engines/connections.
            sync_session_class: Class of the wrapped synchronous session;
                defaults to SQLModel's ``Session``.
            **kw: Extra keyword arguments, forwarded unchanged.
        """
        super().__init__(
            sync_session_class=sync_session_class,
            bind=bind,
            binds=binds,
            **kw,
        )

    @overload
    async def exec(
        self,
        statement: Select[_TSelectParam],
        *,
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
        bind_arguments: Optional[Mapping[str, Any]] = None,
        **kw: Any,
    ) -> Result[_TSelectParam]:
        ...

    @overload
    async def exec(
        self,
        statement: SelectOfScalar[_TSelectParam],
        *,
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
        bind_arguments: Optional[Mapping[str, Any]] = None,
        **kw: Any,
    ) -> ScalarResult[_TSelectParam]:
        ...

    async def exec(
        self,
        statement: Union[
            Select[_TSelectParam],
            SelectOfScalar[_TSelectParam],
            Executable[_TSelectParam],
        ],
        *,
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
        bind_arguments: Optional[Mapping[str, Any]] = None,
        **kw: Any,
    ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
        """Execute ``statement`` and shape the result the way SQLModel does.

        Every argument is handed to the parent class's ``execute`` (via
        ``super()``), so the ``execute`` override below is bypassed.

        Returns:
            ``Result.scalars()`` (a ``ScalarResult``) when ``statement`` is a
            ``SelectOfScalar``; otherwise the ``Result`` unchanged.
        """
        result = await super().execute(
            statement,
            params=params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            **kw,
        )
        if not isinstance(statement, SelectOfScalar):
            return result  # type: ignore
        # Scalar selects unwrap each row to its single column value.
        return result.scalars()  # type: ignore

    async def execute(  # pylint: disable=W0221
        self,
        statement: _Executable,
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
        execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
        bind_arguments: Optional[Mapping[str, Any]] = None,
        **kw: Any,
    ) -> Result[Any]:
        """Pass-through to the parent ``execute``, annotated as SQLModel ``Result``.

        All arguments are forwarded unchanged and the awaited value is
        returned as-is.
        """
        return await super().execute(  # type: ignore
            statement,
            bind_arguments=bind_arguments,
            execution_options=execution_options,
            params=params,
            **kw,
        )

    async def get(  # pylint: disable=W0221
        self,
        entity: Type[_TSelectParam],
        ident: Any,
        options: Optional[Sequence[Any]] = None,
        populate_existing: bool = False,
        with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
        identity_token: Optional[Any] = None,
        execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT,
    ) -> Optional[_TSelectParam]:
        """Fetch the ``entity`` instance identified by ``ident``.

        A typed wrapper over the parent class's ``get``: every option is
        forwarded unchanged, and whatever it returns (possibly ``None``) is
        returned as-is.
        """
        return await super().get(
            entity,
            ident,
            execution_options=execution_options,
            identity_token=identity_token,
            with_for_update=with_for_update,
            populate_existing=populate_existing,
            options=options,
        )