Add Wish Paginator

This commit is contained in:
洛水居室 2023-05-04 19:35:01 +08:00 committed by GitHub
parent 4291309aba
commit ff07ac974f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 13 deletions

View File

@ -76,6 +76,7 @@ class BaseWishClient(BaseClient):
end_id: int,
banner_type: int,
game: Game,
size: int = 20,
lang: Optional[str] = None,
authkey: Optional[str] = None,
) -> Dict[str, Any]:
@ -86,6 +87,7 @@ class BaseWishClient(BaseClient):
end_id (int): The ending ID of the last wish to retrieve.
banner_type (int): The type of banner to retrieve wishes from.
game (Game): The game to make the request for.
size (int, optional): : The number of wishes to retrieve per page, with a default value of 20.
lang (Optional[str], optional): The language code to use for the request.
If not provided, the class default will be used.
authkey (Optional[str], optional): The authorization key for making the request.
@ -98,7 +100,7 @@ class BaseWishClient(BaseClient):
game=game,
lang=lang,
authkey=authkey,
params=dict(gacha_type=banner_type, size=20, end_id=end_id),
params=dict(gacha_type=banner_type, size=size, end_id=end_id),
)
async def get_banner_names(

View File

@ -1,7 +1,9 @@
from functools import partial
from typing import Optional, List
from simnet.client.wish.base import BaseWishClient
from simnet.models.starrail.wish import StarRailWish
from simnet.utils.enum_ import Game
from simnet.utils.paginator import WishPaginator
class WishClient(BaseWishClient):
@ -9,7 +11,7 @@ class WishClient(BaseWishClient):
async def wish_history(
self,
banner_types: List[int],
banner_type: int,
limit: Optional[int] = None,
lang: Optional[str] = None,
authkey: Optional[str] = None,
@ -19,7 +21,7 @@ class WishClient(BaseWishClient):
Get the wish history for a list of banner types.
Args:
banner_types (List[int], optional): The list of banner types to get the wish history for.
banner_type (int, optional): The banner types to get the wish history for.
limit (Optional[int] , optional): The maximum number of wishes to retrieve.
If not provided, all available wishes will be returned.
lang (Optional[str], optional): The language code to use for the request.
@ -30,12 +32,15 @@ class WishClient(BaseWishClient):
Returns:
List[StarRailWish]: A list of StarRailWish objects representing the retrieved wishes.
"""
wish: List[StarRailWish] = []
banner_names = await self.get_banner_names(
game=Game.STARRAIL, lang=lang, authkey=authkey
paginator = WishPaginator(
end_id,
partial(
self.get_wish_page,
banner_type=banner_type,
game=Game.STARRAIL,
authkey=authkey,
),
)
for banner_type in banner_types:
data = await self.get_wish_page(end_id, banner_type, Game.STARRAIL)
banner_name = banner_names[banner_type]
wish = [StarRailWish(**i, banner_name=banner_name) for i in data["list"]]
items = await paginator.get(limit)
wish = [StarRailWish(**i) for i in items]
return wish

View File

@ -47,9 +47,6 @@ class StarRailWish(APIModel):
banner_type: StarRailBannerType = Field(alias="gacha_type")
"""Type of the banner the wish was made on."""
banner_name: str
"""Name of the banner the wish was made on."""
@validator("banner_type", pre=True)
def cast_banner_type(cls, v: Any) -> int:
"""Converts the banner type from any type to int."""

50
simnet/utils/paginator.py Normal file
View File

@ -0,0 +1,50 @@
from typing import List, Dict, Callable, Any, Awaitable
class WishPaginator:
"""
A paginator for fetching and processing wish data.
Attributes:
end_id (int): The ID of the item to stop fetching at.
fetch_data (Callable[..., Awaitable[Dict[str, Any]]]): An asynchronous function to fetch the raw data.
"""
def __init__(
self,
end_id: int,
fetch_data: Callable[..., Awaitable[Dict[str, Any]]],
):
self.end_id = end_id
self.fetch_data = fetch_data
async def get(self, limit: int) -> List[Dict]:
"""
Fetches and returns the items up to the specified limit.
Args:
limit (int): The maximum number of items to return.
Returns:
List[Dict]: The list of fetched items.
"""
all_items = []
current_end_id = 0
while True:
raw_data = await self.fetch_data(end_id=current_end_id)
items = raw_data["list"]
if not items:
break
current_end_id = items[-1]["id"]
filtered_items = [item for item in items if item["id"] != self.end_id]
if len(filtered_items) < len(items):
all_items.extend(filtered_items)
break
all_items.extend(filtered_items)
# Return up to the specified limit.
return all_items[: min(len(all_items), limit)]