From 0a6583a43cae82e7a7584b1352fd4a61b1d6a225 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Sun, 17 Jun 2018 18:41:07 +0200 Subject: [PATCH] Turn the Dispatcher async --- pyrogram/client/dispatcher/dispatcher.py | 66 +++++++++++------------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/pyrogram/client/dispatcher/dispatcher.py b/pyrogram/client/dispatcher/dispatcher.py index 51be2ebb..8efb6584 100644 --- a/pyrogram/client/dispatcher/dispatcher.py +++ b/pyrogram/client/dispatcher/dispatcher.py @@ -16,11 +16,9 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging -import threading from collections import OrderedDict -from queue import Queue -from threading import Thread import pyrogram from pyrogram.api import types @@ -46,29 +44,17 @@ class Dispatcher: def __init__(self, client, workers): self.client = client self.workers = workers - self.workers_list = [] - self.updates = Queue() + + self.update_worker_task = None + self.updates = asyncio.Queue() self.groups = OrderedDict() - def start(self): - for i in range(self.workers): - self.workers_list.append( - Thread( - target=self.update_worker, - name="UpdateWorker#{}".format(i + 1) - ) - ) + async def start(self): + self.update_worker_task = asyncio.ensure_future(self.update_worker()) - self.workers_list[-1].start() - - def stop(self): - for _ in range(self.workers): - self.updates.put(None) - - for i in self.workers_list: - i.join() - - self.workers_list.clear() + async def stop(self): + self.updates.put_nowait(None) + await self.update_worker_task def add_handler(self, handler, group: int): if group not in self.groups: @@ -83,7 +69,9 @@ class Dispatcher: "Handler was not removed.".format(group)) self.groups[group].remove(handler) - def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False): + async def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False): + tasks = [] + for group in self.groups.values(): for handler in group: if is_raw: @@ -112,15 +100,17 @@ class Dispatcher: else: continue - handler.callback(*args) + tasks.append(handler.callback(*args)) break - def update_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) + await asyncio.gather(*tasks) + + async def update_worker(self): + log.info("UpdateWorkerTask started") while True: - update = self.updates.get() + tasks = [] + update = await self.updates.get() if update is None: break @@ -130,7 +120,7 @@ class Dispatcher: chats = {i.id: i for i in update[2]} update = update[0] - self.dispatch(update, users=users, chats=chats, is_raw=True) + tasks.append(self.dispatch(update, users=users, chats=chats, is_raw=True)) if isinstance(update, Dispatcher.MESSAGE_UPDATES): if isinstance(update.message, types.MessageEmpty): @@ -145,7 +135,7 @@ class Dispatcher: is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES) - self.dispatch( + tasks.append(self.dispatch( pyrogram.Update( message=((message if message.chat.type != "channel" else None) if not is_edited_message @@ -160,26 +150,28 @@ class Dispatcher: else None) if is_edited_message else None) ) - ) + )) elif isinstance(update, types.UpdateBotCallbackQuery): - self.dispatch( + tasks.append(self.dispatch( pyrogram.Update( callback_query=utils.parse_callback_query( self.client, update, users ) ) - ) + )) elif isinstance(update, types.UpdateInlineBotCallbackQuery): - self.dispatch( + tasks.append(self.dispatch( pyrogram.Update( callback_query=utils.parse_inline_callback_query( update, users ) ) - ) + )) else: continue + + await asyncio.gather(*tasks) except Exception as e: log.error(e, exc_info=True) - log.debug("{} stopped".format(name)) + log.info("UpdateWorkerTask stopped")