Accommodate the new Dispatcher

This commit is contained in:
Dan 2018-04-06 18:48:41 +02:00
parent 7bd52c3718
commit e98b209526

View File

@ -36,7 +36,6 @@ from queue import Queue
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Event, Thread from threading import Event, Thread
import pyrogram
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.api.core import Object from pyrogram.api.core import Object
from pyrogram.api.errors import ( from pyrogram.api.errors import (
@ -46,10 +45,11 @@ from pyrogram.api.errors import (
PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing, PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing,
ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned, ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned,
VolumeLocNotFound, UserMigrate) VolumeLocNotFound, UserMigrate)
from pyrogram.client import message_parser
from pyrogram.crypto import AES from pyrogram.crypto import AES
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
from pyrogram.session.internals import MsgId from pyrogram.session.internals import MsgId
from .dispatcher import Dispatcher
from .handler import Handler
from .input_media import InputMedia from .input_media import InputMedia
from .style import Markdown, HTML from .style import Markdown, HTML
@ -183,10 +183,13 @@ class Client:
self.is_idle = None self.is_idle = None
self.updates_queue = Queue() self.updates_queue = Queue()
self.update_queue = Queue() self.download_queue = Queue()
self.dispatcher = Dispatcher(self, workers)
self.update_handler = None self.update_handler = None
self.download_queue = Queue() def add_handler(self, handler: Handler, group: int = 0):
self.dispatcher.add_handler(handler, group)
def start(self): def start(self):
"""Use this method to start the Client after creating it. """Use this method to start the Client after creating it.
@ -234,12 +237,11 @@ class Client:
for i in range(self.UPDATES_WORKERS): for i in range(self.UPDATES_WORKERS):
Thread(target=self.updates_worker, name="UpdatesWorker#{}".format(i + 1)).start() Thread(target=self.updates_worker, name="UpdatesWorker#{}".format(i + 1)).start()
for i in range(self.workers):
Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start()
for i in range(self.DOWNLOAD_WORKERS): for i in range(self.DOWNLOAD_WORKERS):
Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start() Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start()
self.dispatcher.start()
mimetypes.init() mimetypes.init()
def stop(self): def stop(self):
@ -255,12 +257,11 @@ class Client:
for _ in range(self.UPDATES_WORKERS): for _ in range(self.UPDATES_WORKERS):
self.updates_queue.put(None) self.updates_queue.put(None)
for _ in range(self.workers):
self.update_queue.put(None)
for _ in range(self.DOWNLOAD_WORKERS): for _ in range(self.DOWNLOAD_WORKERS):
self.download_queue.put(None) self.download_queue.put(None)
self.dispatcher.stop()
def authorize_bot(self): def authorize_bot(self):
try: try:
r = self.send( r = self.send(
@ -684,7 +685,7 @@ class Client:
if len(self.channels_pts[channel_id]) > 50: if len(self.channels_pts[channel_id]) > 50:
self.channels_pts[channel_id] = self.channels_pts[channel_id][25:] self.channels_pts[channel_id] = self.channels_pts[channel_id][25:]
self.update_queue.put((update, updates.users, updates.chats)) self.dispatcher.updates.put((update, updates.users, updates.chats))
elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)): elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
diff = self.send( diff = self.send(
functions.updates.GetDifference( functions.updates.GetDifference(
@ -694,7 +695,7 @@ class Client:
) )
) )
self.update_queue.put(( self.dispatcher.updates.put((
types.UpdateNewMessage( types.UpdateNewMessage(
message=diff.new_messages[0], message=diff.new_messages[0],
pts=updates.pts, pts=updates.pts,
@ -704,54 +705,7 @@ class Client:
diff.chats diff.chats
)) ))
elif isinstance(updates, types.UpdateShort): elif isinstance(updates, types.UpdateShort):
self.update_queue.put((updates.update, [], [])) self.dispatcher.updates.put((updates.update, [], []))
except Exception as e:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))
def update_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
while True:
update = self.update_queue.get()
if update is None:
break
try:
users = {i.id: i for i in update[1]}
chats = {i.id: i for i in update[2]}
update = update[0]
valid_updates = (types.UpdateNewMessage, types.UpdateNewChannelMessage,
types.UpdateEditMessage, types.UpdateEditChannelMessage)
if isinstance(update, valid_updates):
message = update.message
if isinstance(message, types.Message):
m = message_parser.parse_message(self, message, users, chats)
elif isinstance(message, types.MessageService):
m = message_parser.parse_message_service(self, message, users, chats)
else:
continue
else:
continue
edit = isinstance(update, (types.UpdateEditMessage, types.UpdateEditChannelMessage))
u = pyrogram.Update(
update_id=0,
message=(m if m.chat.type is not "channel" else None) if not edit else None,
edited_message=(m if m.chat.type is not "channel" else None) if edit else None,
channel_post=(m if m.chat.type is "channel" else None) if not edit else None,
edited_channel_post=(m if m.chat.type is "channel" else None) if edit else None
)
if self.update_handler:
self.update_handler(self, u)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
@ -778,47 +732,6 @@ class Client:
while self.is_idle: while self.is_idle:
time.sleep(1) time.sleep(1)
def set_update_handler(self, callback: callable):
"""Use this method to set the update handler.
You must call this method *before* you *start()* the Client.
Args:
callback (``callable``):
A function that will be called when a new update is received from the server. It takes
*(client, update, users, chats)* as positional arguments (Look at the section below for
a detailed description).
Other Parameters:
client (:class:`Client <pyrogram.Client>`):
The Client itself, useful when you want to call other API methods inside the update handler.
update (``Update``):
The received update, which can be one of the many single Updates listed in the *updates*
field you see in the :obj:`Update <pyrogram.api.types.Update>` type.
users (``dict``):
Dictionary of all :obj:`User <pyrogram.api.types.User>` mentioned in the update.
You can access extra info about the user (such as *first_name*, *last_name*, etc...) by using
the IDs you find in the *update* argument (e.g.: *users[1768841572]*).
chats (``dict``):
Dictionary of all :obj:`Chat <pyrogram.api.types.Chat>` and
:obj:`Channel <pyrogram.api.types.Channel>` mentioned in the update.
You can access extra info about the chat (such as *title*, *participants_count*, etc...)
by using the IDs you find in the *update* argument (e.g.: *chats[1701277281]*).
Note:
The following Empty or Forbidden types may exist inside the *users* and *chats* dictionaries.
They mean you have been blocked by the user or banned from the group/channel.
- :obj:`UserEmpty <pyrogram.api.types.UserEmpty>`
- :obj:`ChatEmpty <pyrogram.api.types.ChatEmpty>`
- :obj:`ChatForbidden <pyrogram.api.types.ChatForbidden>`
- :obj:`ChannelForbidden <pyrogram.api.types.ChannelForbidden>`
"""
self.update_handler = callback
def send(self, data: Object): def send(self, data: Object):
"""Use this method to send Raw Function queries. """Use this method to send Raw Function queries.