diff --git a/simnet/client/wish/base.py b/simnet/client/wish/base.py index 010db7d..d165ea0 100644 --- a/simnet/client/wish/base.py +++ b/simnet/client/wish/base.py @@ -76,6 +76,7 @@ class BaseWishClient(BaseClient): end_id: int, banner_type: int, game: Game, + size: int = 20, lang: Optional[str] = None, authkey: Optional[str] = None, ) -> Dict[str, Any]: @@ -86,6 +87,7 @@ class BaseWishClient(BaseClient): end_id (int): The ending ID of the last wish to retrieve. banner_type (int): The type of banner to retrieve wishes from. game (Game): The game to make the request for. + size (int, optional): : The number of wishes to retrieve per page, with a default value of 20. lang (Optional[str], optional): The language code to use for the request. If not provided, the class default will be used. authkey (Optional[str], optional): The authorization key for making the request. @@ -98,7 +100,7 @@ class BaseWishClient(BaseClient): game=game, lang=lang, authkey=authkey, - params=dict(gacha_type=banner_type, size=20, end_id=end_id), + params=dict(gacha_type=banner_type, size=size, end_id=end_id), ) async def get_banner_names( diff --git a/simnet/client/wish/starrail.py b/simnet/client/wish/starrail.py index d75c133..f477f0d 100644 --- a/simnet/client/wish/starrail.py +++ b/simnet/client/wish/starrail.py @@ -1,7 +1,9 @@ +from functools import partial from typing import Optional, List from simnet.client.wish.base import BaseWishClient from simnet.models.starrail.wish import StarRailWish from simnet.utils.enum_ import Game +from simnet.utils.paginator import WishPaginator class WishClient(BaseWishClient): @@ -9,7 +11,7 @@ class WishClient(BaseWishClient): async def wish_history( self, - banner_types: List[int], + banner_type: int, limit: Optional[int] = None, lang: Optional[str] = None, authkey: Optional[str] = None, @@ -19,7 +21,7 @@ class WishClient(BaseWishClient): Get the wish history for a list of banner types. Args: - banner_types (List[int], optional): The list of banner types to get the wish history for. + banner_type (int, optional): The banner types to get the wish history for. limit (Optional[int] , optional): The maximum number of wishes to retrieve. If not provided, all available wishes will be returned. lang (Optional[str], optional): The language code to use for the request. @@ -30,12 +32,15 @@ class WishClient(BaseWishClient): Returns: List[StarRailWish]: A list of StarRailWish objects representing the retrieved wishes. """ - wish: List[StarRailWish] = [] - banner_names = await self.get_banner_names( - game=Game.STARRAIL, lang=lang, authkey=authkey + paginator = WishPaginator( + end_id, + partial( + self.get_wish_page, + banner_type=banner_type, + game=Game.STARRAIL, + authkey=authkey, + ), ) - for banner_type in banner_types: - data = await self.get_wish_page(end_id, banner_type, Game.STARRAIL) - banner_name = banner_names[banner_type] - wish = [StarRailWish(**i, banner_name=banner_name) for i in data["list"]] + items = await paginator.get(limit) + wish = [StarRailWish(**i) for i in items] return wish diff --git a/simnet/models/starrail/wish.py b/simnet/models/starrail/wish.py index 622a1c6..6cb16dc 100644 --- a/simnet/models/starrail/wish.py +++ b/simnet/models/starrail/wish.py @@ -47,9 +47,6 @@ class StarRailWish(APIModel): banner_type: StarRailBannerType = Field(alias="gacha_type") """Type of the banner the wish was made on.""" - banner_name: str - """Name of the banner the wish was made on.""" - @validator("banner_type", pre=True) def cast_banner_type(cls, v: Any) -> int: """Converts the banner type from any type to int.""" diff --git a/simnet/utils/paginator.py b/simnet/utils/paginator.py new file mode 100644 index 0000000..8e84b88 --- /dev/null +++ b/simnet/utils/paginator.py @@ -0,0 +1,50 @@ +from typing import List, Dict, Callable, Any, Awaitable + + +class WishPaginator: + """ + A paginator for fetching and processing wish data. + + Attributes: + end_id (int): The ID of the item to stop fetching at. + fetch_data (Callable[..., Awaitable[Dict[str, Any]]]): An asynchronous function to fetch the raw data. + """ + + def __init__( + self, + end_id: int, + fetch_data: Callable[..., Awaitable[Dict[str, Any]]], + ): + self.end_id = end_id + self.fetch_data = fetch_data + + async def get(self, limit: int) -> List[Dict]: + """ + Fetches and returns the items up to the specified limit. + + Args: + limit (int): The maximum number of items to return. + + Returns: + List[Dict]: The list of fetched items. + """ + all_items = [] + current_end_id = 0 + + while True: + raw_data = await self.fetch_data(end_id=current_end_id) + items = raw_data["list"] + if not items: + break + + current_end_id = items[-1]["id"] + + filtered_items = [item for item in items if item["id"] != self.end_id] + if len(filtered_items) < len(items): + all_items.extend(filtered_items) + break + + all_items.extend(filtered_items) + + # Return up to the specified limit. + return all_items[: min(len(all_items), limit)]