fix some linting errors

This commit is contained in:
Mohamed Farahat 2023-03-24 14:56:26 +02:00 committed by Anton De Meester
parent c634631ab1
commit b9cf129188
10 changed files with 66 additions and 26 deletions

View File

@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder): if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs) current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs) # type: ignore return _create_engine(url, **current_kwargs)

View File

@ -1,4 +1,4 @@
from typing import Generic, Iterator, List, Optional, TypeVar from typing import Generic, Iterator, List, Optional, Sequence, TypeVar
from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import Result as _Result
from sqlalchemy.engine.result import ScalarResult as _ScalarResult from sqlalchemy.engine.result import ScalarResult as _ScalarResult
@ -6,24 +6,24 @@ from sqlalchemy.engine.result import ScalarResult as _ScalarResult
_T = TypeVar("_T") _T = TypeVar("_T")
class ScalarResult(_ScalarResult, Generic[_T]): class ScalarResult(_ScalarResult[_T], Generic[_T]):
def all(self) -> List[_T]: def all(self) -> Sequence[_T]:
return super().all() return super().all()
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]:
return super().partitions(size) return super().partitions(size)
def fetchall(self) -> List[_T]: def fetchall(self) -> Sequence[_T]:
return super().fetchall() return super().fetchall()
def fetchmany(self, size: Optional[int] = None) -> List[_T]: def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]:
return super().fetchmany(size) return super().fetchmany(size)
def __iter__(self) -> Iterator[_T]: def __iter__(self) -> Iterator[_T]:
return super().__iter__() return super().__iter__()
def __next__(self) -> _T: def __next__(self) -> _T:
return super().__next__() # type: ignore return super().__next__()
def first(self) -> Optional[_T]: def first(self) -> Optional[_T]:
return super().first() return super().first()
@ -32,10 +32,10 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().one_or_none() return super().one_or_none()
def one(self) -> _T: def one(self) -> _T:
return super().one() # type: ignore return super().one()
class Result(_Result, Generic[_T]): class Result(_Result[_T], Generic[_T]):
def scalars(self, index: int = 0) -> ScalarResult[_T]: def scalars(self, index: int = 0) -> ScalarResult[_T]:
return super().scalars(index) # type: ignore return super().scalars(index) # type: ignore
@ -76,4 +76,4 @@ class Result(_Result, Generic[_T]):
return super().one() # type: ignore return super().one() # type: ignore
def scalar(self) -> Optional[_T]: def scalar(self) -> Optional[_T]:
return super().scalar() return super().scalar() # type: ignore

View File

@ -479,7 +479,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str] __name__: ClassVar[str]
metadata: ClassVar[MetaData] metadata: ClassVar[MetaData]
__allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
class Config: class Config:
orm_mode = True orm_mode = True
@ -523,7 +523,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
return return
else: else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates # Set in SQLAlchemy, before Pydantic to trigger events and updates
if getattr(self.__config__, "table", False) and is_instrumented(self, name): if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore
set_attribute(self, name, value) set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for # Set in Pydantic model to trigger possible validation changes, only for
# non relationship values # non relationship values

View File

@ -118,7 +118,7 @@ class Session(_Session):
Or otherwise you might want to use `session.execute()` instead of Or otherwise you might want to use `session.execute()` instead of
`session.query()`. `session.query()`.
""" """
return super().query(*entities, **kwargs) return super().query(*entities, **kwargs) # type: ignore
def get( def get(
self, self,

View File

@ -406,4 +406,4 @@ def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore
def col(column_expression: Any) -> ColumnClause: # type: ignore def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression return column_expression # type: ignore

View File

@ -16,7 +16,7 @@ class AutoString(types.TypeDecorator): # type: ignore
def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
impl = cast(types.String, self.impl) impl = cast(types.String, self.impl)
if impl.length is None and dialect.name == "mysql": if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore return dialect.type_descriptor(types.String(self.mysql_default_length))
return super().load_dialect_impl(dialect) return super().load_dialect_impl(dialect)
@ -35,9 +35,9 @@ class GUID(types.TypeDecorator): # type: ignore
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
if dialect.name == "postgresql": if dialect.name == "postgresql":
return dialect.type_descriptor(UUID()) # type: ignore return dialect.type_descriptor(UUID())
else: else:
return dialect.type_descriptor(CHAR(32)) # type: ignore return dialect.type_descriptor(CHAR(32))
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
if value is None: if value is None:

View File

@ -173,8 +173,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine) insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__)) indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [ expected_indexes = [
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}}, {
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}}, "name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
] ]
for index in expected_indexes: for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB" assert index in indexes, "This expected index should be in the indexes in DB"

View File

@ -173,8 +173,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine) insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__)) indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [ expected_indexes = [
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}}, {
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}}, "name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
] ]
for index in expected_indexes: for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB" assert index in indexes, "This expected index should be in the indexes in DB"

View File

@ -25,8 +25,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine) insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__)) indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [ expected_indexes = [
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}}, {
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}}, "name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
] ]
for index in expected_indexes: for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB" assert index in indexes, "This expected index should be in the indexes in DB"

View File

@ -26,8 +26,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine) insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__)) indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [ expected_indexes = [
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}}, {
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}}, "name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
] ]
for index in expected_indexes: for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB" assert index in indexes, "This expected index should be in the indexes in DB"