Turn the Dispatcher async

This commit is contained in:
Dan 2018-06-17 18:41:07 +02:00
parent 57f917e6df
commit 0a6583a43c

View File

@ -16,11 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import threading
from collections import OrderedDict from collections import OrderedDict
from queue import Queue
from threading import Thread
import pyrogram import pyrogram
from pyrogram.api import types from pyrogram.api import types
@ -46,29 +44,17 @@ class Dispatcher:
def __init__(self, client, workers): def __init__(self, client, workers):
self.client = client self.client = client
self.workers = workers self.workers = workers
self.workers_list = []
self.updates = Queue() self.update_worker_task = None
self.updates = asyncio.Queue()
self.groups = OrderedDict() self.groups = OrderedDict()
def start(self): async def start(self):
for i in range(self.workers): self.update_worker_task = asyncio.ensure_future(self.update_worker())
self.workers_list.append(
Thread(
target=self.update_worker,
name="UpdateWorker#{}".format(i + 1)
)
)
self.workers_list[-1].start() async def stop(self):
self.updates.put_nowait(None)
def stop(self): await self.update_worker_task
for _ in range(self.workers):
self.updates.put(None)
for i in self.workers_list:
i.join()
self.workers_list.clear()
def add_handler(self, handler, group: int): def add_handler(self, handler, group: int):
if group not in self.groups: if group not in self.groups:
@ -83,7 +69,9 @@ class Dispatcher:
"Handler was not removed.".format(group)) "Handler was not removed.".format(group))
self.groups[group].remove(handler) self.groups[group].remove(handler)
def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False): async def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False):
tasks = []
for group in self.groups.values(): for group in self.groups.values():
for handler in group: for handler in group:
if is_raw: if is_raw:
@ -112,15 +100,17 @@ class Dispatcher:
else: else:
continue continue
handler.callback(*args) tasks.append(handler.callback(*args))
break break
def update_worker(self): await asyncio.gather(*tasks)
name = threading.current_thread().name
log.debug("{} started".format(name)) async def update_worker(self):
log.info("UpdateWorkerTask started")
while True: while True:
update = self.updates.get() tasks = []
update = await self.updates.get()
if update is None: if update is None:
break break
@ -130,7 +120,7 @@ class Dispatcher:
chats = {i.id: i for i in update[2]} chats = {i.id: i for i in update[2]}
update = update[0] update = update[0]
self.dispatch(update, users=users, chats=chats, is_raw=True) tasks.append(self.dispatch(update, users=users, chats=chats, is_raw=True))
if isinstance(update, Dispatcher.MESSAGE_UPDATES): if isinstance(update, Dispatcher.MESSAGE_UPDATES):
if isinstance(update.message, types.MessageEmpty): if isinstance(update.message, types.MessageEmpty):
@ -145,7 +135,7 @@ class Dispatcher:
is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES) is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES)
self.dispatch( tasks.append(self.dispatch(
pyrogram.Update( pyrogram.Update(
message=((message if message.chat.type != "channel" message=((message if message.chat.type != "channel"
else None) if not is_edited_message else None) if not is_edited_message
@ -160,26 +150,28 @@ class Dispatcher:
else None) if is_edited_message else None) if is_edited_message
else None) else None)
) )
) ))
elif isinstance(update, types.UpdateBotCallbackQuery): elif isinstance(update, types.UpdateBotCallbackQuery):
self.dispatch( tasks.append(self.dispatch(
pyrogram.Update( pyrogram.Update(
callback_query=utils.parse_callback_query( callback_query=utils.parse_callback_query(
self.client, update, users self.client, update, users
) )
) )
) ))
elif isinstance(update, types.UpdateInlineBotCallbackQuery): elif isinstance(update, types.UpdateInlineBotCallbackQuery):
self.dispatch( tasks.append(self.dispatch(
pyrogram.Update( pyrogram.Update(
callback_query=utils.parse_inline_callback_query( callback_query=utils.parse_inline_callback_query(
update, users update, users
) )
) )
) ))
else: else:
continue continue
await asyncio.gather(*tasks)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
log.debug("{} stopped".format(name)) log.info("UpdateWorkerTask stopped")