fix some linting errors

This commit is contained in:
Mohamed Farahat 2023-03-24 14:56:26 +02:00 committed by Anton De Meester
parent c634631ab1
commit b9cf129188
10 changed files with 66 additions and 26 deletions

View File

@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs) # type: ignore
return _create_engine(url, **current_kwargs)

View File

@ -1,4 +1,4 @@
from typing import Generic, Iterator, List, Optional, TypeVar
from typing import Generic, Iterator, List, Optional, Sequence, TypeVar
from sqlalchemy.engine.result import Result as _Result
from sqlalchemy.engine.result import ScalarResult as _ScalarResult
@ -6,24 +6,24 @@ from sqlalchemy.engine.result import ScalarResult as _ScalarResult
_T = TypeVar("_T")
class ScalarResult(_ScalarResult, Generic[_T]):
def all(self) -> List[_T]:
class ScalarResult(_ScalarResult[_T], Generic[_T]):
def all(self) -> Sequence[_T]:
return super().all()
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]:
return super().partitions(size)
def fetchall(self) -> List[_T]:
def fetchall(self) -> Sequence[_T]:
return super().fetchall()
def fetchmany(self, size: Optional[int] = None) -> List[_T]:
def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]:
return super().fetchmany(size)
def __iter__(self) -> Iterator[_T]:
return super().__iter__()
def __next__(self) -> _T:
return super().__next__() # type: ignore
return super().__next__()
def first(self) -> Optional[_T]:
return super().first()
@ -32,10 +32,10 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().one_or_none()
def one(self) -> _T:
return super().one() # type: ignore
return super().one()
class Result(_Result, Generic[_T]):
class Result(_Result[_T], Generic[_T]):
def scalars(self, index: int = 0) -> ScalarResult[_T]:
return super().scalars(index) # type: ignore
@ -76,4 +76,4 @@ class Result(_Result, Generic[_T]):
return super().one() # type: ignore
def scalar(self) -> Optional[_T]:
return super().scalar()
return super().scalar() # type: ignore

View File

@ -523,7 +523,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if getattr(self.__config__, "table", False) and is_instrumented(self, name):
if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values

View File

@ -118,7 +118,7 @@ class Session(_Session):
Or otherwise you might want to use `session.execute()` instead of
`session.query()`.
"""
return super().query(*entities, **kwargs)
return super().query(*entities, **kwargs) # type: ignore
def get(
self,

View File

@ -406,4 +406,4 @@ def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression
return column_expression # type: ignore

View File

@ -16,7 +16,7 @@ class AutoString(types.TypeDecorator): # type: ignore
def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
impl = cast(types.String, self.impl)
if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore
return dialect.type_descriptor(types.String(self.mysql_default_length))
return super().load_dialect_impl(dialect)
@ -35,9 +35,9 @@ class GUID(types.TypeDecorator): # type: ignore
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID()) # type: ignore
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(32)) # type: ignore
return dialect.type_descriptor(CHAR(32))
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
if value is None:

View File

@ -173,8 +173,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}},
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}},
{
"name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
]
for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB"

View File

@ -173,8 +173,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}},
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}},
{
"name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
]
for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB"

View File

@ -25,8 +25,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}},
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}},
{
"name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
]
for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB"

View File

@ -26,8 +26,18 @@ def test_tutorial(clear_sqlmodel):
insp: Inspector = inspect(mod.engine)
indexes = insp.get_indexes(str(mod.Hero.__tablename__))
expected_indexes = [
{"name": "ix_hero_name", "column_names": ["name"], "unique": 0, 'dialect_options': {}},
{"name": "ix_hero_age", "column_names": ["age"], "unique": 0, 'dialect_options': {}},
{
"name": "ix_hero_name",
"column_names": ["name"],
"unique": 0,
"dialect_options": {},
},
{
"name": "ix_hero_age",
"column_names": ["age"],
"unique": 0,
"dialect_options": {},
},
]
for index in expected_indexes:
assert index in indexes, "This expected index should be in the indexes in DB"