Merge branch 'develop' into asyncio

# Conflicts:
#	pyrogram/client/dispatcher/dispatcher.py
#	pyrogram/client/ext/utils.py
#	pyrogram/client/types/messages_and_media/message.py
This commit is contained in:
Dan 2018-11-09 09:33:00 +01:00
commit 14feffce84
4 changed files with 107 additions and 170 deletions

View File

@ -20,10 +20,9 @@ import asyncio
import logging
from collections import OrderedDict
import pyrogram
from pyrogram.api import types
from ..ext import utils
from ..handlers import RawUpdateHandler, CallbackQueryHandler, MessageHandler, DeletedMessagesHandler, UserStatusHandler
from ..handlers import CallbackQueryHandler, MessageHandler, DeletedMessagesHandler, UserStatusHandler, RawUpdateHandler
log = logging.getLogger(__name__)
@ -44,9 +43,16 @@ class Dispatcher:
types.UpdateDeleteChannelMessages
)
CALLBACK_QUERY_UPDATES = (
types.UpdateBotCallbackQuery,
types.UpdateInlineBotCallbackQuery
)
MESSAGE_UPDATES = NEW_MESSAGE_UPDATES + EDIT_MESSAGE_UPDATES
def __init__(self, client, workers):
UPDATES = None
def __init__(self, client, workers: int):
self.client = client
self.workers = workers
@ -54,6 +60,22 @@ class Dispatcher:
self.updates = asyncio.Queue()
self.groups = OrderedDict()
Dispatcher.UPDATES = {
Dispatcher.MESSAGE_UPDATES:
lambda upd, usr, cht: (utils.parse_messages(self.client, upd.message, usr, cht), MessageHandler),
Dispatcher.DELETE_MESSAGE_UPDATES:
lambda upd, usr, cht: (utils.parse_deleted_messages(upd), DeletedMessagesHandler),
Dispatcher.CALLBACK_QUERY_UPDATES:
lambda upd, usr, cht: (utils.parse_callback_query(self.client, upd, usr), CallbackQueryHandler),
(types.UpdateUserStatus,):
lambda upd, usr, cht: (utils.parse_user_status(upd.status, upd.user_id), UserStatusHandler)
}
Dispatcher.UPDATES = {key: value for key_tuple, value in Dispatcher.UPDATES.items() for key in key_tuple}
async def start(self):
for i in range(self.workers):
self.update_worker_tasks.append(
@ -82,67 +104,13 @@ class Dispatcher:
def remove_handler(self, handler, group: int):
if group not in self.groups:
raise ValueError("Group {} does not exist. "
"Handler was not removed.".format(group))
raise ValueError("Group {} does not exist. Handler was not removed.".format(group))
self.groups[group].remove(handler)
async def dispatch(self, update, users: dict = None, chats: dict = None, is_raw: bool = False):
tasks = []
for group in self.groups.values():
try:
for handler in group:
if is_raw:
if not isinstance(handler, RawUpdateHandler):
continue
args = (self.client, update, users, chats)
else:
message = (update.message
or update.channel_post
or update.edited_message
or update.edited_channel_post)
deleted_messages = (update.deleted_channel_posts
or update.deleted_messages)
callback_query = update.callback_query
user_status = update.user_status
if message and isinstance(handler, MessageHandler):
if not handler.check(message):
continue
args = (self.client, message)
elif deleted_messages and isinstance(handler, DeletedMessagesHandler):
if not handler.check(deleted_messages):
continue
args = (self.client, deleted_messages)
elif callback_query and isinstance(handler, CallbackQueryHandler):
if not handler.check(callback_query):
continue
args = (self.client, callback_query)
elif user_status and isinstance(handler, UserStatusHandler):
if not handler.check(user_status):
continue
args = (self.client, user_status)
else:
continue
tasks.append(handler.callback(*args))
break
except Exception as e:
log.error(e, exc_info=True)
await asyncio.gather(*tasks)
async def update_worker(self):
def update_worker(self):
while True:
update = await self.updates.get()
update = self.updates.get()
if update is None:
break
@ -152,77 +120,34 @@ class Dispatcher:
chats = {i.id: i for i in update[2]}
update = update[0]
await self.dispatch(update, users=users, chats=chats, is_raw=True)
parser = Dispatcher.UPDATES.get(type(update), None)
if isinstance(update, Dispatcher.MESSAGE_UPDATES):
if isinstance(update.message, types.MessageEmpty):
continue
message = await utils.parse_messages(
self.client,
update.message,
users,
chats
)
is_edited_message = isinstance(update, Dispatcher.EDIT_MESSAGE_UPDATES)
await self.dispatch(
pyrogram.Update(
message=((message if message.chat.type != "channel"
else None) if not is_edited_message
else None),
edited_message=((message if message.chat.type != "channel"
else None) if is_edited_message
else None),
channel_post=((message if message.chat.type == "channel"
else None) if not is_edited_message
else None),
edited_channel_post=((message if message.chat.type == "channel"
else None) if is_edited_message
else None)
)
)
elif isinstance(update, Dispatcher.DELETE_MESSAGE_UPDATES):
is_channel = hasattr(update, 'channel_id')
messages = utils.parse_deleted_messages(
update.messages,
(update.channel_id if is_channel else None)
)
await self.dispatch(
pyrogram.Update(
deleted_messages=(messages if not is_channel else None),
deleted_channel_posts=(messages if is_channel else None)
)
)
elif isinstance(update, types.UpdateBotCallbackQuery):
await self.dispatch(
pyrogram.Update(
callback_query=await utils.parse_callback_query(
self.client, update, users
)
)
)
elif isinstance(update, types.UpdateInlineBotCallbackQuery):
await self.dispatch(
pyrogram.Update(
callback_query=await utils.parse_inline_callback_query(
self.client, update, users
)
)
)
elif isinstance(update, types.UpdateUserStatus):
await self.dispatch(
pyrogram.Update(
user_status=utils.parse_user_status(
update.status, update.user_id
)
)
)
else:
if parser is None:
continue
update, handler_type = parser(update, users, chats)
tasks = []
for group in self.groups.values():
for handler in group:
args = None
if isinstance(handler, RawUpdateHandler):
args = (update, users, chats)
elif isinstance(handler, handler_type):
if handler.check(update):
args = (update,)
if args is None:
continue
try:
tasks.append(handler.callback(self.client, *args))
except Exception as e:
log.error(e, exc_info=True)
finally:
break
await asyncio.gather(*tasks)
except Exception as e:
log.error(e, exc_info=True)

View File

@ -772,10 +772,10 @@ async def parse_messages(
return parsed_messages if is_list else parsed_messages[0]
def parse_deleted_messages(
messages: list,
channel_id: int
) -> pyrogram_types.Messages:
def parse_deleted_messages(update) -> pyrogram_types.Messages:
messages = update.messages
channel_id = getattr(update, "channel_id", None)
parsed_messages = []
for message in messages:
@ -882,42 +882,40 @@ def parse_profile_photos(photos):
)
async def parse_callback_query(client, callback_query, users):
peer = callback_query.peer
async def parse_callback_query(client, update, users):
message = None
inline_message_id = None
if isinstance(peer, types.PeerUser):
peer_id = peer.user_id
elif isinstance(peer, types.PeerChat):
peer_id = -peer.chat_id
else:
peer_id = int("-100" + str(peer.channel_id))
if isinstance(update, types.UpdateBotCallbackQuery):
peer = update.peer
return pyrogram_types.CallbackQuery(
id=str(callback_query.query_id),
from_user=parse_user(users[callback_query.user_id]),
message=await client.get_messages(peer_id, callback_query.msg_id),
chat_instance=str(callback_query.chat_instance),
data=callback_query.data.decode(),
game_short_name=callback_query.game_short_name,
client=client
)
if isinstance(peer, types.PeerUser):
peer_id = peer.user_id
elif isinstance(peer, types.PeerChat):
peer_id = -peer.chat_id
else:
peer_id = int("-100" + str(peer.channel_id))
async def parse_inline_callback_query(client, callback_query, users):
return pyrogram_types.CallbackQuery(
id=str(callback_query.query_id),
from_user=parse_user(users[callback_query.user_id]),
chat_instance=str(callback_query.chat_instance),
inline_message_id=b64encode(
message = client.get_messages(peer_id, update.msg_id)
elif isinstance(update, types.UpdateInlineBotCallbackQuery):
inline_message_id = b64encode(
pack(
"<iqq",
callback_query.msg_id.dc_id,
callback_query.msg_id.id,
callback_query.msg_id.access_hash
update.msg_id.dc_id,
update.msg_id.id,
update.msg_id.access_hash
),
b"-_"
).decode().rstrip("="),
game_short_name=callback_query.game_short_name,
).decode().rstrip("=")
return pyrogram_types.CallbackQuery(
id=str(update.query_id),
from_user=parse_user(users[update.user_id]),
message=message,
inline_message_id=inline_message_id,
chat_instance=str(update.chat_instance),
data=update.data.decode(),
game_short_name=update.game_short_name,
client=client
)

View File

@ -282,15 +282,16 @@ class Filters:
Args:
users (``int`` | ``str`` | ``list``):
Pass one or more user ids/usernames to filter users.
For you yourself, "me" or "self" can be used as well.
Defaults to None (no users).
"""
def __init__(self, users: int or str or list = None):
users = [] if users is None else users if type(users) is list else [users]
super().__init__(
{i.lower().strip("@") if type(i) is str else i for i in users}
{"me" if i in ["me", "self"] else i.lower().strip("@") if type(i) is str else i for i in users}
if type(users) is list else
{users.lower().strip("@") if type(users) is str else users}
{"me" if users in ["me", "self"] else users.lower().strip("@") if type(users) is str else users}
)
def __call__(self, message):
@ -298,7 +299,9 @@ class Filters:
message.from_user
and (message.from_user.id in self
or (message.from_user.username
and message.from_user.username.lower() in self))
and message.from_user.username.lower() in self)
or ("me" in self
and message.from_user.is_self))
)
# noinspection PyPep8Naming
@ -311,7 +314,7 @@ class Filters:
Args:
chats (``int`` | ``str`` | ``list``):
Pass one or more chat ids/usernames to filter chats.
For your personal cloud (Saved Messages) you can simply use me or self.
For your personal cloud (Saved Messages) you can simply use "me" or "self".
Defaults to None (no chats).
"""

View File

@ -634,7 +634,7 @@ class Message(Object):
else:
raise ValueError("The message doesn't contain any keyboard")
async def download(self, file_name: str = "", block: bool = True):
async def download(self, file_name: str = "", block: bool = True, progress: callable = None, progress_args: tuple = None):
"""Bound method *download* of :obj:`Message <pyrogram.Message>`.
Use as a shortcut for:
@ -659,6 +659,15 @@ class Message(Object):
Blocks the code execution until the file has been downloaded.
Defaults to True.
progress (``callable``):
Pass a callback function to view the download progress.
The function must take *(client, current, total, \*args)* as positional arguments (look at the section
below for a detailed description).
progress_args (``tuple``):
Extra custom arguments for the progress callback function. Useful, for example, if you want to pass
a chat_id and a message_id in order to edit a message with the updated progress.
Returns:
On success, the absolute path of the downloaded file as string is returned, None otherwise.
@ -669,5 +678,7 @@ class Message(Object):
return await self._client.download_media(
message=self,
file_name=file_name,
block=block
block=block,
progress=progress,
progress_args=progress_args,
)