diff --git a/pyrogram/client/dispatcher/dispatcher.py b/pyrogram/client/dispatcher/dispatcher.py index 61cd8c7c..26a0011b 100644 --- a/pyrogram/client/dispatcher/dispatcher.py +++ b/pyrogram/client/dispatcher/dispatcher.py @@ -20,10 +20,9 @@ import asyncio import logging from collections import OrderedDict -import pyrogram from pyrogram.api import types from ..ext import utils -from ..handlers import RawUpdateHandler, CallbackQueryHandler, MessageHandler, DeletedMessagesHandler, UserStatusHandler +from ..handlers import CallbackQueryHandler, MessageHandler, DeletedMessagesHandler, UserStatusHandler, RawUpdateHandler log = logging.getLogger(__name__) @@ -44,9 +43,16 @@ class Dispatcher: types.UpdateDeleteChannelMessages ) + CALLBACK_QUERY_UPDATES = ( + types.UpdateBotCallbackQuery, + types.UpdateInlineBotCallbackQuery + ) + MESSAGE_UPDATES = NEW_MESSAGE_UPDATES + EDIT_MESSAGE_UPDATES - def __init__(self, client, workers): + UPDATES = None + + def __init__(self, client, workers: int): self.client = client self.workers = workers @@ -54,6 +60,22 @@ class Dispatcher: self.updates = asyncio.Queue() self.groups = OrderedDict() + Dispatcher.UPDATES = { + Dispatcher.MESSAGE_UPDATES: + lambda upd, usr, cht: (utils.parse_messages(self.client, upd.message, usr, cht), MessageHandler), + + Dispatcher.DELETE_MESSAGE_UPDATES: + lambda upd, usr, cht: (utils.parse_deleted_messages(upd), DeletedMessagesHandler), + + Dispatcher.CALLBACK_QUERY_UPDATES: + lambda upd, usr, cht: (utils.parse_callback_query(self.client, upd, usr), CallbackQueryHandler), + + (types.UpdateUserStatus,): + lambda upd, usr, cht: (utils.parse_user_status(upd.status, upd.user_id), UserStatusHandler) + } + + Dispatcher.UPDATES = {key: value for key_tuple, value in Dispatcher.UPDATES.items() for key in key_tuple} + async def start(self): for i in range(self.workers): self.update_worker_tasks.append( @@ -82,67 +104,13 @@ class Dispatcher: def remove_handler(self, handler, group: int): if group not in self.groups: - raise ValueError("Group {} does not exist. " - "Handler was not removed.".format(group)) + raise ValueError("Group {} does not exist. Handler was not removed.".format(group)) + self.groups[group].remove(handler) - async def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False): - tasks = [] - - for group in self.groups.values(): - try: - for handler in group: - if is_raw: - if not isinstance(handler, RawUpdateHandler): - continue - - args = (self.client, update, users, chats) - else: - message = (update.message - or update.channel_post - or update.edited_message - or update.edited_channel_post) - - deleted_messages = (update.deleted_channel_posts - or update.deleted_messages) - - callback_query = update.callback_query - - user_status = update.user_status - - if message and isinstance(handler, MessageHandler): - if not handler.check(message): - continue - - args = (self.client, message) - elif deleted_messages and isinstance(handler, DeletedMessagesHandler): - if not handler.check(deleted_messages): - continue - - args = (self.client, deleted_messages) - elif callback_query and isinstance(handler, CallbackQueryHandler): - if not handler.check(callback_query): - continue - - args = (self.client, callback_query) - elif user_status and isinstance(handler, UserStatusHandler): - if not handler.check(user_status): - continue - - args = (self.client, user_status) - else: - continue - - tasks.append(handler.callback(*args)) - break - except Exception as e: - log.error(e, exc_info=True) - - await asyncio.gather(*tasks) - - async def update_worker(self): + def update_worker(self): while True: - update = await self.updates.get() + update = self.updates.get() if update is None: break @@ -152,77 +120,34 @@ class Dispatcher: chats = {i.id: i for i in update[2]} update = update[0] - await self.dispatch(update, users=users, chats=chats, is_raw=True) + parser = Dispatcher.UPDATES.get(type(update), None) - if isinstance(update, Dispatcher.MESSAGE_UPDATES): - if isinstance(update.message, types.MessageEmpty): - continue - - message = await utils.parse_messages( - self.client, - update.message, - users, - chats - ) - - is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES) - - await self.dispatch( - pyrogram.Update( - message=((message if message.chat.type != "channel" - else None) if not is_edited_message - else None), - edited_message=((message if message.chat.type != "channel" - else None) if is_edited_message - else None), - channel_post=((message if message.chat.type == "channel" - else None) if not is_edited_message - else None), - edited_channel_post=((message if message.chat.type == "channel" - else None) if is_edited_message - else None) - ) - ) - - elif isinstance(update, Dispatcher.DELETE_MESSAGE_UPDATES): - is_channel = hasattr(update, 'channel_id') - - messages = utils.parse_deleted_messages( - update.messages, - (update.channel_id if is_channel else None) - ) - - await self.dispatch( - pyrogram.Update( - deleted_messages=(messages if not is_channel else None), - deleted_channel_posts=(messages if is_channel else None) - ) - ) - elif isinstance(update, types.UpdateBotCallbackQuery): - await self.dispatch( - pyrogram.Update( - callback_query=await utils.parse_callback_query( - self.client, update, users - ) - ) - ) - elif isinstance(update, types.UpdateInlineBotCallbackQuery): - await self.dispatch( - pyrogram.Update( - callback_query=await utils.parse_inline_callback_query( - self.client, update, users - ) - ) - ) - elif isinstance(update, types.UpdateUserStatus): - await self.dispatch( - pyrogram.Update( - user_status=utils.parse_user_status( - update.status, update.user_id - ) - ) - ) - else: + if parser is None: continue + + update, handler_type = parser(update, users, chats) + tasks = [] + + for group in self.groups.values(): + for handler in group: + args = None + + if isinstance(handler, RawUpdateHandler): + args = (update, users, chats) + elif isinstance(handler, handler_type): + if handler.check(update): + args = (update,) + + if args is None: + continue + + try: + tasks.append(handler.callback(self.client, *args)) + except Exception as e: + log.error(e, exc_info=True) + finally: + break + + await asyncio.gather(*tasks) except Exception as e: log.error(e, exc_info=True) diff --git a/pyrogram/client/ext/utils.py b/pyrogram/client/ext/utils.py index c58be74c..3dc78c60 100644 --- a/pyrogram/client/ext/utils.py +++ b/pyrogram/client/ext/utils.py @@ -772,10 +772,10 @@ async def parse_messages( return parsed_messages if is_list else parsed_messages[0] -def parse_deleted_messages( - messages: list, - channel_id: int -) -> pyrogram_types.Messages: +def parse_deleted_messages(update) -> pyrogram_types.Messages: + messages = update.messages + channel_id = getattr(update, "channel_id", None) + parsed_messages = [] for message in messages: @@ -882,42 +882,40 @@ def parse_profile_photos(photos): ) -async def parse_callback_query(client, callback_query, users): - peer = callback_query.peer +async def parse_callback_query(client, update, users): + message = None + inline_message_id = None - if isinstance(peer, types.PeerUser): - peer_id = peer.user_id - elif isinstance(peer, types.PeerChat): - peer_id = -peer.chat_id - else: - peer_id = int("-100" + str(peer.channel_id)) + if isinstance(update, types.UpdateBotCallbackQuery): + peer = update.peer - return pyrogram_types.CallbackQuery( - id=str(callback_query.query_id), - from_user=parse_user(users[callback_query.user_id]), - message=await client.get_messages(peer_id, callback_query.msg_id), - chat_instance=str(callback_query.chat_instance), - data=callback_query.data.decode(), - game_short_name=callback_query.game_short_name, - client=client - ) + if isinstance(peer, types.PeerUser): + peer_id = peer.user_id + elif isinstance(peer, types.PeerChat): + peer_id = -peer.chat_id + else: + peer_id = int("-100" + str(peer.channel_id)) - -async def parse_inline_callback_query(client, callback_query, users): - return pyrogram_types.CallbackQuery( - id=str(callback_query.query_id), - from_user=parse_user(users[callback_query.user_id]), - chat_instance=str(callback_query.chat_instance), - inline_message_id=b64encode( + message = client.get_messages(peer_id, update.msg_id) + elif isinstance(update, types.UpdateInlineBotCallbackQuery): + inline_message_id = b64encode( pack( "`. Use as a shortcut for: @@ -659,6 +659,15 @@ class Message(Object): Blocks the code execution until the file has been downloaded. Defaults to True. + progress (``callable``): + Pass a callback function to view the download progress. + The function must take *(client, current, total, \*args)* as positional arguments (look at the section + below for a detailed description). + + progress_args (``tuple``): + Extra custom arguments for the progress callback function. Useful, for example, if you want to pass + a chat_id and a message_id in order to edit a message with the updated progress. + Returns: On success, the absolute path of the downloaded file as string is returned, None otherwise. @@ -669,5 +678,7 @@ class Message(Object): return await self._client.download_media( message=self, file_name=file_name, - block=block + block=block, + progress=progress, + progress_args=progress_args, )