Merge branch 'develop' into asyncio

# Conflicts:
#	pyrogram/client/dispatcher/dispatcher.py
#	pyrogram/client/ext/utils.py
#	pyrogram/client/types/messages_and_media/message.py
This commit is contained in:
Dan 2018-11-09 09:33:00 +01:00
commit 14feffce84
4 changed files with 107 additions and 170 deletions

View File

@ -20,10 +20,9 @@ import asyncio
import logging import logging
from collections import OrderedDict from collections import OrderedDict
import pyrogram
from pyrogram.api import types from pyrogram.api import types
from ..ext import utils from ..ext import utils
from ..handlers import RawUpdateHandler, CallbackQueryHandler, MessageHandler, DeletedMessagesHandler, UserStatusHandler from ..handlers import CallbackQueryHandler, MessageHandler, DeletedMessagesHandler, UserStatusHandler, RawUpdateHandler
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -44,9 +43,16 @@ class Dispatcher:
types.UpdateDeleteChannelMessages types.UpdateDeleteChannelMessages
) )
CALLBACK_QUERY_UPDATES = (
types.UpdateBotCallbackQuery,
types.UpdateInlineBotCallbackQuery
)
MESSAGE_UPDATES = NEW_MESSAGE_UPDATES + EDIT_MESSAGE_UPDATES MESSAGE_UPDATES = NEW_MESSAGE_UPDATES + EDIT_MESSAGE_UPDATES
def __init__(self, client, workers): UPDATES = None
def __init__(self, client, workers: int):
self.client = client self.client = client
self.workers = workers self.workers = workers
@ -54,6 +60,22 @@ class Dispatcher:
self.updates = asyncio.Queue() self.updates = asyncio.Queue()
self.groups = OrderedDict() self.groups = OrderedDict()
Dispatcher.UPDATES = {
Dispatcher.MESSAGE_UPDATES:
lambda upd, usr, cht: (utils.parse_messages(self.client, upd.message, usr, cht), MessageHandler),
Dispatcher.DELETE_MESSAGE_UPDATES:
lambda upd, usr, cht: (utils.parse_deleted_messages(upd), DeletedMessagesHandler),
Dispatcher.CALLBACK_QUERY_UPDATES:
lambda upd, usr, cht: (utils.parse_callback_query(self.client, upd, usr), CallbackQueryHandler),
(types.UpdateUserStatus,):
lambda upd, usr, cht: (utils.parse_user_status(upd.status, upd.user_id), UserStatusHandler)
}
Dispatcher.UPDATES = {key: value for key_tuple, value in Dispatcher.UPDATES.items() for key in key_tuple}
async def start(self): async def start(self):
for i in range(self.workers): for i in range(self.workers):
self.update_worker_tasks.append( self.update_worker_tasks.append(
@ -82,67 +104,13 @@ class Dispatcher:
def remove_handler(self, handler, group: int): def remove_handler(self, handler, group: int):
if group not in self.groups: if group not in self.groups:
raise ValueError("Group {} does not exist. " raise ValueError("Group {} does not exist. Handler was not removed.".format(group))
"Handler was not removed.".format(group))
self.groups[group].remove(handler) self.groups[group].remove(handler)
async def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False): def update_worker(self):
tasks = []
for group in self.groups.values():
try:
for handler in group:
if is_raw:
if not isinstance(handler, RawUpdateHandler):
continue
args = (self.client, update, users, chats)
else:
message = (update.message
or update.channel_post
or update.edited_message
or update.edited_channel_post)
deleted_messages = (update.deleted_channel_posts
or update.deleted_messages)
callback_query = update.callback_query
user_status = update.user_status
if message and isinstance(handler, MessageHandler):
if not handler.check(message):
continue
args = (self.client, message)
elif deleted_messages and isinstance(handler, DeletedMessagesHandler):
if not handler.check(deleted_messages):
continue
args = (self.client, deleted_messages)
elif callback_query and isinstance(handler, CallbackQueryHandler):
if not handler.check(callback_query):
continue
args = (self.client, callback_query)
elif user_status and isinstance(handler, UserStatusHandler):
if not handler.check(user_status):
continue
args = (self.client, user_status)
else:
continue
tasks.append(handler.callback(*args))
break
except Exception as e:
log.error(e, exc_info=True)
await asyncio.gather(*tasks)
async def update_worker(self):
while True: while True:
update = await self.updates.get() update = self.updates.get()
if update is None: if update is None:
break break
@ -152,77 +120,34 @@ class Dispatcher:
chats = {i.id: i for i in update[2]} chats = {i.id: i for i in update[2]}
update = update[0] update = update[0]
await self.dispatch(update, users=users, chats=chats, is_raw=True) parser = Dispatcher.UPDATES.get(type(update), None)
if isinstance(update, Dispatcher.MESSAGE_UPDATES): if parser is None:
if isinstance(update.message, types.MessageEmpty):
continue continue
message = await utils.parse_messages( update, handler_type = parser(update, users, chats)
self.client, tasks = []
update.message,
users,
chats
)
is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES) for group in self.groups.values():
for handler in group:
args = None
await self.dispatch( if isinstance(handler, RawUpdateHandler):
pyrogram.Update( args = (update, users, chats)
message=((message if message.chat.type != "channel" elif isinstance(handler, handler_type):
else None) if not is_edited_message if handler.check(update):
else None), args = (update,)
edited_message=((message if message.chat.type != "channel"
else None) if is_edited_message
else None),
channel_post=((message if message.chat.type == "channel"
else None) if not is_edited_message
else None),
edited_channel_post=((message if message.chat.type == "channel"
else None) if is_edited_message
else None)
)
)
elif isinstance(update, Dispatcher.DELETE_MESSAGE_UPDATES): if args is None:
is_channel = hasattr(update, 'channel_id')
messages = utils.parse_deleted_messages(
update.messages,
(update.channel_id if is_channel else None)
)
await self.dispatch(
pyrogram.Update(
deleted_messages=(messages if not is_channel else None),
deleted_channel_posts=(messages if is_channel else None)
)
)
elif isinstance(update, types.UpdateBotCallbackQuery):
await self.dispatch(
pyrogram.Update(
callback_query=await utils.parse_callback_query(
self.client, update, users
)
)
)
elif isinstance(update, types.UpdateInlineBotCallbackQuery):
await self.dispatch(
pyrogram.Update(
callback_query=await utils.parse_inline_callback_query(
self.client, update, users
)
)
)
elif isinstance(update, types.UpdateUserStatus):
await self.dispatch(
pyrogram.Update(
user_status=utils.parse_user_status(
update.status, update.user_id
)
)
)
else:
continue continue
try:
tasks.append(handler.callback(self.client, *args))
except Exception as e:
log.error(e, exc_info=True)
finally:
break
await asyncio.gather(*tasks)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)

View File

@ -772,10 +772,10 @@ async def parse_messages(
return parsed_messages if is_list else parsed_messages[0] return parsed_messages if is_list else parsed_messages[0]
def parse_deleted_messages( def parse_deleted_messages(update) -> pyrogram_types.Messages:
messages: list, messages = update.messages
channel_id: int channel_id = getattr(update, "channel_id", None)
) -> pyrogram_types.Messages:
parsed_messages = [] parsed_messages = []
for message in messages: for message in messages:
@ -882,8 +882,12 @@ def parse_profile_photos(photos):
) )
async def parse_callback_query(client, callback_query, users): async def parse_callback_query(client, update, users):
peer = callback_query.peer message = None
inline_message_id = None
if isinstance(update, types.UpdateBotCallbackQuery):
peer = update.peer
if isinstance(peer, types.PeerUser): if isinstance(peer, types.PeerUser):
peer_id = peer.user_id peer_id = peer.user_id
@ -892,32 +896,26 @@ async def parse_callback_query(client, callback_query, users):
else: else:
peer_id = int("-100" + str(peer.channel_id)) peer_id = int("-100" + str(peer.channel_id))
return pyrogram_types.CallbackQuery( message = client.get_messages(peer_id, update.msg_id)
id=str(callback_query.query_id), elif isinstance(update, types.UpdateInlineBotCallbackQuery):
from_user=parse_user(users[callback_query.user_id]), inline_message_id = b64encode(
message=await client.get_messages(peer_id, callback_query.msg_id),
chat_instance=str(callback_query.chat_instance),
data=callback_query.data.decode(),
game_short_name=callback_query.game_short_name,
client=client
)
async def parse_inline_callback_query(client, callback_query, users):
return pyrogram_types.CallbackQuery(
id=str(callback_query.query_id),
from_user=parse_user(users[callback_query.user_id]),
chat_instance=str(callback_query.chat_instance),
inline_message_id=b64encode(
pack( pack(
"<iqq", "<iqq",
callback_query.msg_id.dc_id, update.msg_id.dc_id,
callback_query.msg_id.id, update.msg_id.id,
callback_query.msg_id.access_hash update.msg_id.access_hash
), ),
b"-_" b"-_"
).decode().rstrip("="), ).decode().rstrip("=")
game_short_name=callback_query.game_short_name,
return pyrogram_types.CallbackQuery(
id=str(update.query_id),
from_user=parse_user(users[update.user_id]),
message=message,
inline_message_id=inline_message_id,
chat_instance=str(update.chat_instance),
data=update.data.decode(),
game_short_name=update.game_short_name,
client=client client=client
) )

View File

@ -282,15 +282,16 @@ class Filters:
Args: Args:
users (``int`` | ``str`` | ``list``): users (``int`` | ``str`` | ``list``):
Pass one or more user ids/usernames to filter users. Pass one or more user ids/usernames to filter users.
For you yourself, "me" or "self" can be used as well.
Defaults to None (no users). Defaults to None (no users).
""" """
def __init__(self, users: int or str or list = None): def __init__(self, users: int or str or list = None):
users = [] if users is None else users if type(users) is list else [users] users = [] if users is None else users if type(users) is list else [users]
super().__init__( super().__init__(
{i.lower().strip("@") if type(i) is str else i for i in users} {"me" if i in ["me", "self"] else i.lower().strip("@") if type(i) is str else i for i in users}
if type(users) is list else if type(users) is list else
{users.lower().strip("@") if type(users) is str else users} {"me" if users in ["me", "self"] else users.lower().strip("@") if type(users) is str else users}
) )
def __call__(self, message): def __call__(self, message):
@ -298,7 +299,9 @@ class Filters:
message.from_user message.from_user
and (message.from_user.id in self and (message.from_user.id in self
or (message.from_user.username or (message.from_user.username
and message.from_user.username.lower() in self)) and message.from_user.username.lower() in self)
or ("me" in self
and message.from_user.is_self))
) )
# noinspection PyPep8Naming # noinspection PyPep8Naming
@ -311,7 +314,7 @@ class Filters:
Args: Args:
chats (``int`` | ``str`` | ``list``): chats (``int`` | ``str`` | ``list``):
Pass one or more chat ids/usernames to filter chats. Pass one or more chat ids/usernames to filter chats.
For your personal cloud (Saved Messages) you can simply use me or self. For your personal cloud (Saved Messages) you can simply use "me" or "self".
Defaults to None (no chats). Defaults to None (no chats).
""" """

View File

@ -634,7 +634,7 @@ class Message(Object):
else: else:
raise ValueError("The message doesn't contain any keyboard") raise ValueError("The message doesn't contain any keyboard")
async def download(self, file_name: str = "", block: bool = True): async def download(self, file_name: str = "", block: bool = True, progress: callable = None, progress_args: tuple = None):
"""Bound method *download* of :obj:`Message <pyrogram.Message>`. """Bound method *download* of :obj:`Message <pyrogram.Message>`.
Use as a shortcut for: Use as a shortcut for:
@ -659,6 +659,15 @@ class Message(Object):
Blocks the code execution until the file has been downloaded. Blocks the code execution until the file has been downloaded.
Defaults to True. Defaults to True.
progress (``callable``):
Pass a callback function to view the download progress.
The function must take *(client, current, total, \*args)* as positional arguments (look at the section
below for a detailed description).
progress_args (``tuple``):
Extra custom arguments for the progress callback function. Useful, for example, if you want to pass
a chat_id and a message_id in order to edit a message with the updated progress.
Returns: Returns:
On success, the absolute path of the downloaded file as string is returned, None otherwise. On success, the absolute path of the downloaded file as string is returned, None otherwise.
@ -669,5 +678,7 @@ class Message(Object):
return await self._client.download_media( return await self._client.download_media(
message=self, message=self,
file_name=file_name, file_name=file_name,
block=block block=block,
progress=progress,
progress_args=progress_args,
) )