Fix linting issues

This commit is contained in:
Santiago Martinez 2023-08-09 01:43:50 +01:00
parent 972ee56fde
commit 179183c018
5 changed files with 44 additions and 69 deletions

View File

@ -1,9 +1,11 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union from typing import Any, Dict, Mapping, Optional, Sequence, Type, TypeVar, Union
from sqlalchemy import util from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.orm import Mapper
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.util.concurrency import greenlet_spawn from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable from sqlmodel.sql.base import Executable
@ -14,13 +16,18 @@ from ...sql.expression import Select
_T = TypeVar("_T") _T = TypeVar("_T")
BindsType = Dict[
Union[Type[Any], Mapper[Any], TableClause, str], Union[AsyncEngine, AsyncConnection]
]
class AsyncSession(_AsyncSession): class AsyncSession(_AsyncSession):
sync_session: Session sync_session: Session
def __init__( def __init__(
self, self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, binds: Optional[BindsType] = None,
**kw: Any, **kw: Any,
): ):
# All the same code of the original AsyncSession # All the same code of the original AsyncSession

View File

@ -61,6 +61,9 @@ def __dataclass_transform__(
class FieldInfo(PydanticFieldInfo): class FieldInfo(PydanticFieldInfo):
nullable: Union[bool, PydanticUndefinedType]
def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False) primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", PydanticUndefined) nullable = kwargs.pop("nullable", PydanticUndefined)
@ -587,7 +590,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def _is_field_noneable(field: FieldInfo) -> bool: def _is_field_noneable(field: FieldInfo) -> bool:
if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: if not isinstance(field.nullable, PydanticUndefinedType):
return field.nullable return field.nullable
if not field.is_required(): if not field.is_required():
default = getattr(field, "original_default", field.default) default = getattr(field, "original_default", field.default)

View File

@ -12,9 +12,7 @@ from typing import (
from sqlalchemy import util from sqlalchemy import util
from sqlalchemy.orm import Mapper as _Mapper from sqlalchemy.orm import Mapper as _Mapper
from sqlalchemy.orm import Query as _Query
from sqlalchemy.orm import Session as _Session from sqlalchemy.orm import Session as _Session
from sqlalchemy.sql.base import Executable as _Executable
from sqlalchemy.sql.selectable import ForUpdateArg as _ForUpdateArg from sqlalchemy.sql.selectable import ForUpdateArg as _ForUpdateArg
from sqlmodel.sql.expression import Select, SelectOfScalar from sqlmodel.sql.expression import Select, SelectOfScalar
@ -81,56 +79,6 @@ class Session(_Session):
return results.scalars() # type: ignore return results.scalars() # type: ignore
return results # type: ignore return results # type: ignore
def execute(
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[Any]:
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = session.exec(select(Hero)).all()
```
"""
return super().execute( # type: ignore
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
**kw,
)
def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
"""
🚨 You probably want to use `session.exec()` instead of `session.query()`.
`session.exec()` is SQLModel's own short version with increased type
annotations.
Or otherwise you might want to use `session.execute()` instead of
`session.query()`.
"""
return super().query(*entities, **kwargs) # type: ignore
def get( def get(
self, self,
entity: Union[Type[_TSelectParam], "_Mapper[_TSelectParam]"], entity: Union[Type[_TSelectParam], "_Mapper[_TSelectParam]"],
@ -152,3 +100,34 @@ class Session(_Session):
execution_options=execution_options, execution_options=execution_options,
bind_arguments=bind_arguments, bind_arguments=bind_arguments,
) )
Session.query.__doc__ = """
🚨 You probably want to use `session.exec()` instead of `session.query()`.
`session.exec()` is SQLModel's own short version with increased type
annotations.
Or otherwise you might want to use `session.execute()` instead of
`session.query()`.
"""
Session.execute.__doc__ = """
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = session.exec(select(Hero)).all()
```
"""

View File

@ -35,14 +35,6 @@ class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]):
inherit_cache = True inherit_cache = True
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ..main import SQLModel from ..main import SQLModel

View File

@ -30,12 +30,6 @@ class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]):
class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]):
inherit_cache = True inherit_cache = True
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ..main import SQLModel from ..main import SQLModel