diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index c85ff8d7..74c2d01c 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -36,7 +36,6 @@ from queue import Queue from signal import signal, SIGINT, SIGTERM, SIGABRT from threading import Event, Thread -import pyrogram from pyrogram.api import functions, types from pyrogram.api.core import Object from pyrogram.api.errors import ( @@ -46,10 +45,11 @@ from pyrogram.api.errors import ( PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing, ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned, VolumeLocNotFound, UserMigrate) -from pyrogram.client import message_parser from pyrogram.crypto import AES from pyrogram.session import Auth, Session from pyrogram.session.internals import MsgId +from .dispatcher import Dispatcher +from .handler import Handler from .input_media import InputMedia from .style import Markdown, HTML @@ -183,10 +183,13 @@ class Client: self.is_idle = None self.updates_queue = Queue() - self.update_queue = Queue() + self.download_queue = Queue() + + self.dispatcher = Dispatcher(self, workers) self.update_handler = None - self.download_queue = Queue() + def add_handler(self, handler: Handler, group: int = 0): + self.dispatcher.add_handler(handler, group) def start(self): """Use this method to start the Client after creating it. @@ -234,12 +237,11 @@ class Client: for i in range(self.UPDATES_WORKERS): Thread(target=self.updates_worker, name="UpdatesWorker#{}".format(i + 1)).start() - for i in range(self.workers): - Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start() - for i in range(self.DOWNLOAD_WORKERS): Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start() + self.dispatcher.start() + mimetypes.init() def stop(self): @@ -255,12 +257,11 @@ class Client: for _ in range(self.UPDATES_WORKERS): self.updates_queue.put(None) - for _ in range(self.workers): - self.update_queue.put(None) - for _ in range(self.DOWNLOAD_WORKERS): self.download_queue.put(None) + self.dispatcher.stop() + def authorize_bot(self): try: r = self.send( @@ -684,7 +685,7 @@ class Client: if len(self.channels_pts[channel_id]) > 50: self.channels_pts[channel_id] = self.channels_pts[channel_id][25:] - self.update_queue.put((update, updates.users, updates.chats)) + self.dispatcher.updates.put((update, updates.users, updates.chats)) elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)): diff = self.send( functions.updates.GetDifference( @@ -694,7 +695,7 @@ class Client: ) ) - self.update_queue.put(( + self.dispatcher.updates.put(( types.UpdateNewMessage( message=diff.new_messages[0], pts=updates.pts, @@ -704,54 +705,7 @@ class Client: diff.chats )) elif isinstance(updates, types.UpdateShort): - self.update_queue.put((updates.update, [], [])) - except Exception as e: - log.error(e, exc_info=True) - - log.debug("{} stopped".format(name)) - - def update_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) - - while True: - update = self.update_queue.get() - - if update is None: - break - - try: - users = {i.id: i for i in update[1]} - chats = {i.id: i for i in update[2]} - update = update[0] - - valid_updates = (types.UpdateNewMessage, types.UpdateNewChannelMessage, - types.UpdateEditMessage, types.UpdateEditChannelMessage) - - if isinstance(update, valid_updates): - message = update.message - - if isinstance(message, types.Message): - m = message_parser.parse_message(self, message, users, chats) - elif isinstance(message, types.MessageService): - m = message_parser.parse_message_service(self, message, users, chats) - else: - continue - else: - continue - - edit = isinstance(update, (types.UpdateEditMessage, types.UpdateEditChannelMessage)) - - u = pyrogram.Update( - update_id=0, - message=(m if m.chat.type is not "channel" else None) if not edit else None, - edited_message=(m if m.chat.type is not "channel" else None) if edit else None, - channel_post=(m if m.chat.type is "channel" else None) if not edit else None, - edited_channel_post=(m if m.chat.type is "channel" else None) if edit else None - ) - - if self.update_handler: - self.update_handler(self, u) + self.dispatcher.updates.put((updates.update, [], [])) except Exception as e: log.error(e, exc_info=True) @@ -778,47 +732,6 @@ class Client: while self.is_idle: time.sleep(1) - def set_update_handler(self, callback: callable): - """Use this method to set the update handler. - - You must call this method *before* you *start()* the Client. - - Args: - callback (``callable``): - A function that will be called when a new update is received from the server. It takes - *(client, update, users, chats)* as positional arguments (Look at the section below for - a detailed description). - - Other Parameters: - client (:class:`Client `): - The Client itself, useful when you want to call other API methods inside the update handler. - - update (``Update``): - The received update, which can be one of the many single Updates listed in the *updates* - field you see in the :obj:`Update ` type. - - users (``dict``): - Dictionary of all :obj:`User ` mentioned in the update. - You can access extra info about the user (such as *first_name*, *last_name*, etc...) by using - the IDs you find in the *update* argument (e.g.: *users[1768841572]*). - - chats (``dict``): - Dictionary of all :obj:`Chat ` and - :obj:`Channel ` mentioned in the update. - You can access extra info about the chat (such as *title*, *participants_count*, etc...) - by using the IDs you find in the *update* argument (e.g.: *chats[1701277281]*). - - Note: - The following Empty or Forbidden types may exist inside the *users* and *chats* dictionaries. - They mean you have been blocked by the user or banned from the group/channel. - - - :obj:`UserEmpty ` - - :obj:`ChatEmpty ` - - :obj:`ChatForbidden ` - - :obj:`ChannelForbidden ` - """ - self.update_handler = callback - def send(self, data: Object): """Use this method to send Raw Function queries.