mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-18 21:44:22 +00:00
Turn the Dispatcher async
This commit is contained in:
parent
57f917e6df
commit
0a6583a43c
@ -16,11 +16,9 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from queue import Queue
|
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram.api import types
|
from pyrogram.api import types
|
||||||
@ -46,29 +44,17 @@ class Dispatcher:
|
|||||||
def __init__(self, client, workers):
|
def __init__(self, client, workers):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.workers = workers
|
self.workers = workers
|
||||||
self.workers_list = []
|
|
||||||
self.updates = Queue()
|
self.update_worker_task = None
|
||||||
|
self.updates = asyncio.Queue()
|
||||||
self.groups = OrderedDict()
|
self.groups = OrderedDict()
|
||||||
|
|
||||||
def start(self):
|
async def start(self):
|
||||||
for i in range(self.workers):
|
self.update_worker_task = asyncio.ensure_future(self.update_worker())
|
||||||
self.workers_list.append(
|
|
||||||
Thread(
|
|
||||||
target=self.update_worker,
|
|
||||||
name="UpdateWorker#{}".format(i + 1)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.workers_list[-1].start()
|
async def stop(self):
|
||||||
|
self.updates.put_nowait(None)
|
||||||
def stop(self):
|
await self.update_worker_task
|
||||||
for _ in range(self.workers):
|
|
||||||
self.updates.put(None)
|
|
||||||
|
|
||||||
for i in self.workers_list:
|
|
||||||
i.join()
|
|
||||||
|
|
||||||
self.workers_list.clear()
|
|
||||||
|
|
||||||
def add_handler(self, handler, group: int):
|
def add_handler(self, handler, group: int):
|
||||||
if group not in self.groups:
|
if group not in self.groups:
|
||||||
@ -83,7 +69,9 @@ class Dispatcher:
|
|||||||
"Handler was not removed.".format(group))
|
"Handler was not removed.".format(group))
|
||||||
self.groups[group].remove(handler)
|
self.groups[group].remove(handler)
|
||||||
|
|
||||||
def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False):
|
async def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False):
|
||||||
|
tasks = []
|
||||||
|
|
||||||
for group in self.groups.values():
|
for group in self.groups.values():
|
||||||
for handler in group:
|
for handler in group:
|
||||||
if is_raw:
|
if is_raw:
|
||||||
@ -112,15 +100,17 @@ class Dispatcher:
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
handler.callback(*args)
|
tasks.append(handler.callback(*args))
|
||||||
break
|
break
|
||||||
|
|
||||||
def update_worker(self):
|
await asyncio.gather(*tasks)
|
||||||
name = threading.current_thread().name
|
|
||||||
log.debug("{} started".format(name))
|
async def update_worker(self):
|
||||||
|
log.info("UpdateWorkerTask started")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
update = self.updates.get()
|
tasks = []
|
||||||
|
update = await self.updates.get()
|
||||||
|
|
||||||
if update is None:
|
if update is None:
|
||||||
break
|
break
|
||||||
@ -130,7 +120,7 @@ class Dispatcher:
|
|||||||
chats = {i.id: i for i in update[2]}
|
chats = {i.id: i for i in update[2]}
|
||||||
update = update[0]
|
update = update[0]
|
||||||
|
|
||||||
self.dispatch(update, users=users, chats=chats, is_raw=True)
|
tasks.append(self.dispatch(update, users=users, chats=chats, is_raw=True))
|
||||||
|
|
||||||
if isinstance(update, Dispatcher.MESSAGE_UPDATES):
|
if isinstance(update, Dispatcher.MESSAGE_UPDATES):
|
||||||
if isinstance(update.message, types.MessageEmpty):
|
if isinstance(update.message, types.MessageEmpty):
|
||||||
@ -145,7 +135,7 @@ class Dispatcher:
|
|||||||
|
|
||||||
is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES)
|
is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES)
|
||||||
|
|
||||||
self.dispatch(
|
tasks.append(self.dispatch(
|
||||||
pyrogram.Update(
|
pyrogram.Update(
|
||||||
message=((message if message.chat.type != "channel"
|
message=((message if message.chat.type != "channel"
|
||||||
else None) if not is_edited_message
|
else None) if not is_edited_message
|
||||||
@ -160,26 +150,28 @@ class Dispatcher:
|
|||||||
else None) if is_edited_message
|
else None) if is_edited_message
|
||||||
else None)
|
else None)
|
||||||
)
|
)
|
||||||
)
|
))
|
||||||
elif isinstance(update, types.UpdateBotCallbackQuery):
|
elif isinstance(update, types.UpdateBotCallbackQuery):
|
||||||
self.dispatch(
|
tasks.append(self.dispatch(
|
||||||
pyrogram.Update(
|
pyrogram.Update(
|
||||||
callback_query=utils.parse_callback_query(
|
callback_query=utils.parse_callback_query(
|
||||||
self.client, update, users
|
self.client, update, users
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
))
|
||||||
elif isinstance(update, types.UpdateInlineBotCallbackQuery):
|
elif isinstance(update, types.UpdateInlineBotCallbackQuery):
|
||||||
self.dispatch(
|
tasks.append(self.dispatch(
|
||||||
pyrogram.Update(
|
pyrogram.Update(
|
||||||
callback_query=utils.parse_inline_callback_query(
|
callback_query=utils.parse_inline_callback_query(
|
||||||
update, users
|
update, users
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
))
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e, exc_info=True)
|
log.error(e, exc_info=True)
|
||||||
|
|
||||||
log.debug("{} stopped".format(name))
|
log.info("UpdateWorkerTask stopped")
|
||||||
|
Loading…
Reference in New Issue
Block a user