PaiGram/core/sqlmodel/session.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

119 lines
4.0 KiB
Python
Raw Normal View History

from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload
from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.engine.result import Result, ScalarResult
from sqlmodel.orm.session import Session
from sqlmodel.sql.base import Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal
_TSelectParam = TypeVar("_TSelectParam")
__all__ = ("AsyncSession",)
class AsyncSession(_AsyncSession): # pylint: disable=W0223
sync_session_class = Session
sync_session: Session
def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
sync_session_class: Type[Session] = Session,
**kw: Any,
):
super().__init__(
bind=bind,
binds=binds,
sync_session_class=sync_session_class,
**kw,
)
@overload
async def exec(
self,
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Result[_TSelectParam]:
...
@overload
async def exec(
self,
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...
async def exec(
self,
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
results = super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)
if isinstance(statement, SelectOfScalar):
return (await results).scalars() # type: ignore
return await results # type: ignore
async def execute( # pylint: disable=W0221
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Result[Any]:
return await super().execute( # type: ignore
statement=statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)
async def get( # pylint: disable=W0221
self,
entity: Type[_TSelectParam],
ident: Any,
options: Optional[Sequence[Any]] = None,
populate_existing: bool = False,
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
identity_token: Optional[Any] = None,
execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT,
) -> Optional[_TSelectParam]:
return await super().get(
entity=entity,
ident=ident,
options=options,
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
execution_options=execution_options,
)