mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-18 21:44:22 +00:00
Turn the Dispatcher async
This commit is contained in:
parent
57f917e6df
commit
0a6583a43c
@ -16,11 +16,9 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
import pyrogram
|
||||
from pyrogram.api import types
|
||||
@ -46,29 +44,17 @@ class Dispatcher:
|
||||
def __init__(self, client, workers):
|
||||
self.client = client
|
||||
self.workers = workers
|
||||
self.workers_list = []
|
||||
self.updates = Queue()
|
||||
|
||||
self.update_worker_task = None
|
||||
self.updates = asyncio.Queue()
|
||||
self.groups = OrderedDict()
|
||||
|
||||
def start(self):
|
||||
for i in range(self.workers):
|
||||
self.workers_list.append(
|
||||
Thread(
|
||||
target=self.update_worker,
|
||||
name="UpdateWorker#{}".format(i + 1)
|
||||
)
|
||||
)
|
||||
async def start(self):
|
||||
self.update_worker_task = asyncio.ensure_future(self.update_worker())
|
||||
|
||||
self.workers_list[-1].start()
|
||||
|
||||
def stop(self):
|
||||
for _ in range(self.workers):
|
||||
self.updates.put(None)
|
||||
|
||||
for i in self.workers_list:
|
||||
i.join()
|
||||
|
||||
self.workers_list.clear()
|
||||
async def stop(self):
|
||||
self.updates.put_nowait(None)
|
||||
await self.update_worker_task
|
||||
|
||||
def add_handler(self, handler, group: int):
|
||||
if group not in self.groups:
|
||||
@ -83,7 +69,9 @@ class Dispatcher:
|
||||
"Handler was not removed.".format(group))
|
||||
self.groups[group].remove(handler)
|
||||
|
||||
def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False):
|
||||
async def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False):
|
||||
tasks = []
|
||||
|
||||
for group in self.groups.values():
|
||||
for handler in group:
|
||||
if is_raw:
|
||||
@ -112,15 +100,17 @@ class Dispatcher:
|
||||
else:
|
||||
continue
|
||||
|
||||
handler.callback(*args)
|
||||
tasks.append(handler.callback(*args))
|
||||
break
|
||||
|
||||
def update_worker(self):
|
||||
name = threading.current_thread().name
|
||||
log.debug("{} started".format(name))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def update_worker(self):
|
||||
log.info("UpdateWorkerTask started")
|
||||
|
||||
while True:
|
||||
update = self.updates.get()
|
||||
tasks = []
|
||||
update = await self.updates.get()
|
||||
|
||||
if update is None:
|
||||
break
|
||||
@ -130,7 +120,7 @@ class Dispatcher:
|
||||
chats = {i.id: i for i in update[2]}
|
||||
update = update[0]
|
||||
|
||||
self.dispatch(update, users=users, chats=chats, is_raw=True)
|
||||
tasks.append(self.dispatch(update, users=users, chats=chats, is_raw=True))
|
||||
|
||||
if isinstance(update, Dispatcher.MESSAGE_UPDATES):
|
||||
if isinstance(update.message, types.MessageEmpty):
|
||||
@ -145,7 +135,7 @@ class Dispatcher:
|
||||
|
||||
is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES)
|
||||
|
||||
self.dispatch(
|
||||
tasks.append(self.dispatch(
|
||||
pyrogram.Update(
|
||||
message=((message if message.chat.type != "channel"
|
||||
else None) if not is_edited_message
|
||||
@ -160,26 +150,28 @@ class Dispatcher:
|
||||
else None) if is_edited_message
|
||||
else None)
|
||||
)
|
||||
)
|
||||
))
|
||||
elif isinstance(update, types.UpdateBotCallbackQuery):
|
||||
self.dispatch(
|
||||
tasks.append(self.dispatch(
|
||||
pyrogram.Update(
|
||||
callback_query=utils.parse_callback_query(
|
||||
self.client, update, users
|
||||
)
|
||||
)
|
||||
)
|
||||
))
|
||||
elif isinstance(update, types.UpdateInlineBotCallbackQuery):
|
||||
self.dispatch(
|
||||
tasks.append(self.dispatch(
|
||||
pyrogram.Update(
|
||||
callback_query=utils.parse_inline_callback_query(
|
||||
update, users
|
||||
)
|
||||
)
|
||||
)
|
||||
))
|
||||
else:
|
||||
continue
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
log.error(e, exc_info=True)
|
||||
|
||||
log.debug("{} stopped".format(name))
|
||||
log.info("UpdateWorkerTask stopped")
|
||||
|
Loading…
Reference in New Issue
Block a user