From c8c6faa96e4bdde24e9c99035b7b035d972c21ae Mon Sep 17 00:00:00 2001 From: CyanBook Date: Fri, 21 Aug 2020 07:28:27 +0200 Subject: [PATCH] Change logging hierarchy for loading plugins (#451) Loading plugins shouldn't be considered a warning --- pyrogram/__init__.py | 2 +- pyrogram/client/client.py | 416 ++++++++++-------- pyrogram/client/ext/base_client.py | 77 ++-- pyrogram/client/ext/dispatcher.py | 159 ++++--- pyrogram/client/ext/syncer.py | 40 +- pyrogram/client/ext/utils.py | 24 +- .../methods/bots/answer_callback_query.py | 4 +- .../methods/bots/answer_inline_query.py | 11 +- .../methods/bots/get_game_high_scores.py | 8 +- .../methods/bots/get_inline_bot_results.py | 6 +- .../methods/bots/request_callback_answer.py | 6 +- pyrogram/client/methods/bots/send_game.py | 8 +- .../methods/bots/send_inline_bot_result.py | 6 +- .../client/methods/bots/set_game_score.py | 10 +- .../client/methods/chats/add_chat_members.py | 12 +- .../client/methods/chats/archive_chats.py | 21 +- .../client/methods/chats/create_channel.py | 4 +- pyrogram/client/methods/chats/create_group.py | 6 +- .../client/methods/chats/create_supergroup.py | 4 +- .../client/methods/chats/delete_channel.py | 6 +- .../client/methods/chats/delete_chat_photo.py | 8 +- .../client/methods/chats/delete_supergroup.py | 6 +- .../methods/chats/export_chat_invite_link.py | 8 +- pyrogram/client/methods/chats/get_chat.py | 14 +- .../client/methods/chats/get_chat_member.py | 11 +- .../client/methods/chats/get_chat_members.py | 8 +- .../methods/chats/get_chat_members_count.py | 16 +- pyrogram/client/methods/chats/get_dialogs.py | 8 +- .../client/methods/chats/get_dialogs_count.py | 6 +- .../client/methods/chats/get_nearby_chats.py | 4 +- .../client/methods/chats/iter_chat_members.py | 15 +- pyrogram/client/methods/chats/iter_dialogs.py | 29 +- pyrogram/client/methods/chats/join_chat.py | 8 +- .../client/methods/chats/kick_chat_member.py | 12 +- pyrogram/client/methods/chats/leave_chat.py | 12 +- .../client/methods/chats/pin_chat_message.py | 6 +- .../methods/chats/promote_chat_member.py | 8 +- .../methods/chats/restrict_chat_member.py | 8 +- .../methods/chats/set_administrator_title.py | 12 +- .../methods/chats/set_chat_description.py | 6 +- .../methods/chats/set_chat_permissions.py | 6 +- .../client/methods/chats/set_chat_photo.py | 16 +- .../client/methods/chats/set_chat_title.py | 8 +- .../client/methods/chats/set_slow_mode.py | 6 +- .../client/methods/chats/unarchive_chats.py | 21 +- .../client/methods/chats/unban_chat_member.py | 8 +- .../methods/chats/unpin_chat_message.py | 6 +- .../methods/chats/update_chat_username.py | 6 +- .../client/methods/contacts/add_contacts.py | 4 +- .../methods/contacts/delete_contacts.py | 6 +- .../client/methods/contacts/get_contacts.py | 5 +- .../methods/contacts/get_contacts_count.py | 4 +- .../methods/messages/delete_messages.py | 8 +- .../client/methods/messages/download_media.py | 9 +- .../methods/messages/edit_inline_caption.py | 4 +- .../methods/messages/edit_inline_media.py | 6 +- .../messages/edit_inline_reply_markup.py | 4 +- .../methods/messages/edit_inline_text.py | 6 +- .../methods/messages/edit_message_caption.py | 4 +- .../methods/messages/edit_message_media.py | 46 +- .../messages/edit_message_reply_markup.py | 8 +- .../methods/messages/edit_message_text.py | 10 +- .../methods/messages/forward_messages.py | 15 +- .../client/methods/messages/get_history.py | 9 +- .../methods/messages/get_history_count.py | 6 +- .../client/methods/messages/get_messages.py | 9 +- .../client/methods/messages/iter_history.py | 13 +- .../client/methods/messages/read_history.py | 6 +- .../client/methods/messages/retract_vote.py | 6 +- .../client/methods/messages/search_global.py | 17 +- .../methods/messages/search_messages.py | 22 +- .../client/methods/messages/send_animation.py | 22 +- .../client/methods/messages/send_audio.py | 23 +- .../methods/messages/send_cached_media.py | 10 +- .../methods/messages/send_chat_action.py | 6 +- .../client/methods/messages/send_contact.py | 8 +- pyrogram/client/methods/messages/send_dice.py | 13 +- .../client/methods/messages/send_document.py | 20 +- .../client/methods/messages/send_location.py | 8 +- .../methods/messages/send_media_group.py | 30 +- .../client/methods/messages/send_message.py | 12 +- .../client/methods/messages/send_photo.py | 16 +- pyrogram/client/methods/messages/send_poll.py | 8 +- .../client/methods/messages/send_sticker.py | 14 +- .../client/methods/messages/send_venue.py | 8 +- .../client/methods/messages/send_video.py | 20 +- .../methods/messages/send_video_note.py | 18 +- .../client/methods/messages/send_voice.py | 16 +- pyrogram/client/methods/messages/stop_poll.py | 8 +- pyrogram/client/methods/messages/vote_poll.py | 8 +- .../methods/password/change_cloud_password.py | 6 +- .../methods/password/enable_cloud_password.py | 6 +- .../methods/password/remove_cloud_password.py | 6 +- pyrogram/client/methods/users/block_user.py | 6 +- .../methods/users/delete_profile_photos.py | 4 +- .../client/methods/users/get_common_chats.py | 6 +- pyrogram/client/methods/users/get_me.py | 6 +- .../methods/users/get_profile_photos.py | 10 +- .../methods/users/get_profile_photos_count.py | 8 +- pyrogram/client/methods/users/get_users.py | 7 +- .../methods/users/iter_profile_photos.py | 13 +- .../client/methods/users/set_profile_photo.py | 8 +- pyrogram/client/methods/users/unblock_user.py | 6 +- .../client/methods/users/update_profile.py | 4 +- .../client/methods/users/update_username.py | 4 +- pyrogram/client/parser/html.py | 4 +- pyrogram/client/parser/markdown.py | 4 +- pyrogram/client/parser/parser.py | 8 +- .../bots_and_keyboards/callback_query.py | 30 +- .../client/types/inline_mode/inline_query.py | 4 +- .../types/inline_mode/inline_query_result.py | 2 +- .../inline_query_result_animation.py | 4 +- .../inline_query_result_article.py | 4 +- .../inline_mode/inline_query_result_photo.py | 6 +- .../input_text_message_content.py | 4 +- .../types/messages_and_media/message.py | 146 +++--- .../types/messages_and_media/sticker.py | 17 +- pyrogram/client/types/update.py | 4 +- pyrogram/client/types/user_and_chats/chat.py | 66 +-- pyrogram/client/types/user_and_chats/user.py | 8 +- pyrogram/connection/connection.py | 30 +- pyrogram/connection/transport/tcp/__init__.py | 1 + pyrogram/connection/transport/tcp/tcp.py | 65 ++- .../connection/transport/tcp/tcp_abridged.py | 18 +- .../transport/tcp/tcp_abridged_o.py | 18 +- pyrogram/connection/transport/tcp/tcp_full.py | 23 +- .../transport/tcp/tcp_intermediate.py | 16 +- .../transport/tcp/tcp_intermediate_o.py | 16 +- pyrogram/crypto/__init__.py | 1 + pyrogram/crypto/mtproto.py | 62 +++ pyrogram/session/auth.py | 19 +- pyrogram/session/session.py | 265 +++++------ requirements.txt | 4 +- 133 files changed, 1349 insertions(+), 1207 deletions(-) create mode 100644 pyrogram/crypto/mtproto.py diff --git a/pyrogram/__init__.py b/pyrogram/__init__.py index b6cc36dc..053d3398 100644 --- a/pyrogram/__init__.py +++ b/pyrogram/__init__.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -__version__ = "0.18.0" +__version__ = "0.18.0-async" __license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)" __copyright__ = "Copyright (C) 2017-2020 Dan " diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index d2add202..4d0312f4 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import io import logging import math @@ -23,14 +24,11 @@ import os import re import shutil import tempfile -import threading -import time from configparser import ConfigParser from hashlib import sha256, md5 from importlib import import_module from pathlib import Path from signal import signal, SIGINT, SIGTERM, SIGABRT -from threading import Thread from typing import Union, List, BinaryIO from pyrogram.api import functions, types @@ -40,11 +38,13 @@ from pyrogram.client.handlers.handler import Handler from pyrogram.client.methods.password.utils import compute_check from pyrogram.crypto import AES from pyrogram.errors import ( - PhoneMigrate, NetworkMigrate, SessionPasswordNeeded, PeerIdInvalid, VolumeLocNotFound, UserMigrate, ChannelPrivate, + PhoneMigrate, NetworkMigrate, SessionPasswordNeeded, + PeerIdInvalid, VolumeLocNotFound, UserMigrate, ChannelPrivate, AuthBytesInvalid, BadRequest ) from pyrogram.session import Auth, Session from .ext import utils, Syncer, BaseClient, Dispatcher +from .ext.utils import ainput from .methods import Methods from .storage import Storage, FileStorage, MemoryStorage from .types import User, SentCode, TermsOfService @@ -127,7 +127,7 @@ class Client(Methods, BaseClient): Defaults to False. workers (``int``, *optional*): - Thread pool size for handling incoming updates. + Number of maximum concurrent workers for handling incoming updates. Defaults to 4. workdir (``str``, *optional*): @@ -243,6 +243,12 @@ class Client(Methods, BaseClient): except ConnectionError: pass + async def __aenter__(self): + return await self.start() + + async def __aexit__(self, *args): + await self.stop() + @property def proxy(self): return self._proxy @@ -259,7 +265,7 @@ class Client(Methods, BaseClient): self._proxy["enabled"] = bool(value.get("enabled", True)) self._proxy.update(value) - def connect(self) -> bool: + async def connect(self) -> bool: """ Connect the client to Telegram servers. @@ -274,16 +280,17 @@ class Client(Methods, BaseClient): raise ConnectionError("Client is already connected") self.load_config() - self.load_session() + await self.load_session() self.session = Session(self, self.storage.dc_id(), self.storage.auth_key()) - self.session.start() + + await self.session.start() self.is_connected = True return bool(self.storage.user_id()) - def disconnect(self): + async def disconnect(self): """Disconnect the client from Telegram servers. Raises: @@ -296,11 +303,11 @@ class Client(Methods, BaseClient): if self.is_initialized: raise ConnectionError("Can't disconnect an initialized client") - self.session.stop() + await self.session.stop() self.storage.close() self.is_connected = False - def initialize(self): + async def initialize(self): """Initialize the client by starting up workers. This method will start updates and download workers. @@ -319,33 +326,26 @@ class Client(Methods, BaseClient): self.load_plugins() if not self.no_updates: - for i in range(self.UPDATES_WORKERS): - self.updates_workers_list.append( - Thread( - target=self.updates_worker, - name="UpdatesWorker#{}".format(i + 1) - ) + for _ in range(Client.UPDATES_WORKERS): + self.updates_worker_tasks.append( + asyncio.ensure_future(self.updates_worker()) ) - self.updates_workers_list[-1].start() + logging.info("Started {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS)) - for i in range(self.DOWNLOAD_WORKERS): - self.download_workers_list.append( - Thread( - target=self.download_worker, - name="DownloadWorker#{}".format(i + 1) - ) + for _ in range(Client.DOWNLOAD_WORKERS): + self.download_worker_tasks.append( + asyncio.ensure_future(self.download_worker()) ) - self.download_workers_list[-1].start() + logging.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS)) - self.dispatcher.start() - - Syncer.add(self) + await self.dispatcher.start() + await Syncer.add(self) self.is_initialized = True - def terminate(self): + async def terminate(self): """Terminate the client by shutting down workers. This method does the opposite of :meth:`~Client.initialize`. @@ -358,37 +358,41 @@ class Client(Methods, BaseClient): raise ConnectionError("Client is already terminated") if self.takeout_id: - self.send(functions.account.FinishTakeoutSession()) + await self.send(functions.account.FinishTakeoutSession()) log.warning("Takeout session {} finished".format(self.takeout_id)) - Syncer.remove(self) - self.dispatcher.stop() + await Syncer.remove(self) + await self.dispatcher.stop() - for _ in range(self.DOWNLOAD_WORKERS): - self.download_queue.put(None) + for _ in range(Client.DOWNLOAD_WORKERS): + self.download_queue.put_nowait(None) - for i in self.download_workers_list: - i.join() + for task in self.download_worker_tasks: + await task - self.download_workers_list.clear() + self.download_worker_tasks.clear() + + logging.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS)) if not self.no_updates: - for _ in range(self.UPDATES_WORKERS): - self.updates_queue.put(None) + for _ in range(Client.UPDATES_WORKERS): + self.updates_queue.put_nowait(None) - for i in self.updates_workers_list: - i.join() + for task in self.updates_worker_tasks: + await task - self.updates_workers_list.clear() + self.updates_worker_tasks.clear() - for i in self.media_sessions.values(): - i.stop() + logging.info("Stopped {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS)) + + for media_session in self.media_sessions.values(): + await media_session.stop() self.media_sessions.clear() self.is_initialized = False - def send_code(self, phone_number: str) -> SentCode: + async def send_code(self, phone_number: str) -> SentCode: """Send the confirmation code to the given phone number. Parameters: @@ -405,7 +409,7 @@ class Client(Methods, BaseClient): while True: try: - r = self.send( + r = await self.send( functions.auth.SendCode( phone_number=phone_number, api_id=self.api_id, @@ -414,17 +418,17 @@ class Client(Methods, BaseClient): ) ) except (PhoneMigrate, NetworkMigrate) as e: - self.session.stop() + await self.session.stop() self.storage.dc_id(e.x) - self.storage.auth_key(Auth(self, self.storage.dc_id()).create()) + self.storage.auth_key(await Auth(self, self.storage.dc_id()).create()) self.session = Session(self, self.storage.dc_id(), self.storage.auth_key()) - self.session.start() + await self.session.start() else: return SentCode._parse(r) - def resend_code(self, phone_number: str, phone_code_hash: str) -> SentCode: + async def resend_code(self, phone_number: str, phone_code_hash: str) -> SentCode: """Re-send the confirmation code using a different type. The type of the code to be re-sent is specified in the *next_type* attribute of the :obj:`SentCode` object @@ -445,7 +449,7 @@ class Client(Methods, BaseClient): """ phone_number = phone_number.strip(" +") - r = self.send( + r = await self.send( functions.auth.ResendCode( phone_number=phone_number, phone_code_hash=phone_code_hash @@ -454,7 +458,8 @@ class Client(Methods, BaseClient): return SentCode._parse(r) - def sign_in(self, phone_number: str, phone_code_hash: str, phone_code: str) -> Union[User, TermsOfService, bool]: + async def sign_in(self, phone_number: str, phone_code_hash: str, phone_code: str) -> Union[ + User, TermsOfService, bool]: """Authorize a user in Telegram with a valid confirmation code. Parameters: @@ -479,7 +484,7 @@ class Client(Methods, BaseClient): """ phone_number = phone_number.strip(" +") - r = self.send( + r = await self.send( functions.auth.SignIn( phone_number=phone_number, phone_code_hash=phone_code_hash, @@ -498,7 +503,7 @@ class Client(Methods, BaseClient): return User._parse(self, r.user) - def sign_up(self, phone_number: str, phone_code_hash: str, first_name: str, last_name: str = "") -> User: + async def sign_up(self, phone_number: str, phone_code_hash: str, first_name: str, last_name: str = "") -> User: """Register a new user in Telegram. Parameters: @@ -522,7 +527,7 @@ class Client(Methods, BaseClient): """ phone_number = phone_number.strip(" +") - r = self.send( + r = await self.send( functions.auth.SignUp( phone_number=phone_number, first_name=first_name, @@ -536,7 +541,7 @@ class Client(Methods, BaseClient): return User._parse(self, r.user) - def sign_in_bot(self, bot_token: str) -> User: + async def sign_in_bot(self, bot_token: str) -> User: """Authorize a bot using its bot token generated by BotFather. Parameters: @@ -551,7 +556,7 @@ class Client(Methods, BaseClient): """ while True: try: - r = self.send( + r = await self.send( functions.auth.ImportBotAuthorization( flags=0, api_id=self.api_id, @@ -560,28 +565,28 @@ class Client(Methods, BaseClient): ) ) except UserMigrate as e: - self.session.stop() + await self.session.stop() self.storage.dc_id(e.x) - self.storage.auth_key(Auth(self, self.storage.dc_id()).create()) + self.storage.auth_key(await Auth(self, self.storage.dc_id()).create()) self.session = Session(self, self.storage.dc_id(), self.storage.auth_key()) - self.session.start() + await self.session.start() else: self.storage.user_id(r.user.id) self.storage.is_bot(True) return User._parse(self, r.user) - def get_password_hint(self) -> str: + async def get_password_hint(self) -> str: """Get your Two-Step Verification password hint. Returns: ``str``: On success, the password hint as string is returned. """ - return self.send(functions.account.GetPassword()).hint + return (await self.send(functions.account.GetPassword())).hint - def check_password(self, password: str) -> User: + async def check_password(self, password: str) -> User: """Check your Two-Step Verification password and log in. Parameters: @@ -594,10 +599,10 @@ class Client(Methods, BaseClient): Raises: BadRequest: In case the password is invalid. """ - r = self.send( + r = await self.send( functions.auth.CheckPassword( password=compute_check( - self.send(functions.account.GetPassword()), + await self.send(functions.account.GetPassword()), password ) ) @@ -608,7 +613,7 @@ class Client(Methods, BaseClient): return User._parse(self, r.user) - def send_recovery_code(self) -> str: + async def send_recovery_code(self) -> str: """Send a code to your email to recover your password. Returns: @@ -617,11 +622,11 @@ class Client(Methods, BaseClient): Raises: BadRequest: In case no recovery email was set up. """ - return self.send( + return (await self.send( functions.auth.RequestPasswordRecovery() - ).email_pattern + )).email_pattern - def recover_password(self, recovery_code: str) -> User: + async def recover_password(self, recovery_code: str) -> User: """Recover your password with a recovery code and log in. Parameters: @@ -634,7 +639,7 @@ class Client(Methods, BaseClient): Raises: BadRequest: In case the recovery code is invalid. """ - r = self.send( + r = await self.send( functions.auth.RecoverPassword( code=recovery_code ) @@ -645,14 +650,14 @@ class Client(Methods, BaseClient): return User._parse(self, r.user) - def accept_terms_of_service(self, terms_of_service_id: str) -> bool: + async def accept_terms_of_service(self, terms_of_service_id: str) -> bool: """Accept the given terms of service. Parameters: terms_of_service_id (``str``): The terms of service identifier. """ - r = self.send( + r = await self.send( functions.help.AcceptTermsOfService( id=types.DataJSON( data=terms_of_service_id @@ -664,15 +669,15 @@ class Client(Methods, BaseClient): return True - def authorize(self) -> User: + async def authorize(self) -> User: if self.bot_token: - return self.sign_in_bot(self.bot_token) + return await self.sign_in_bot(self.bot_token) while True: try: if not self.phone_number: while True: - value = input("Enter phone number or bot token: ") + value = await ainput("Enter phone number or bot token: ") if not value: continue @@ -684,11 +689,11 @@ class Client(Methods, BaseClient): if ":" in value: self.bot_token = value - return self.sign_in_bot(value) + return await self.sign_in_bot(value) else: self.phone_number = value - sent_code = self.send_code(self.phone_number) + sent_code = await self.send_code(self.phone_number) except BadRequest as e: print(e.MESSAGE) self.phone_number = None @@ -697,7 +702,7 @@ class Client(Methods, BaseClient): break if self.force_sms: - sent_code = self.resend_code(self.phone_number, sent_code.phone_code_hash) + sent_code = await self.resend_code(self.phone_number, sent_code.phone_code_hash) print("The confirmation code has been sent via {}".format( { @@ -710,10 +715,10 @@ class Client(Methods, BaseClient): while True: if not self.phone_code: - self.phone_code = input("Enter confirmation code: ") + self.phone_code = await ainput("Enter confirmation code: ") try: - signed_in = self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code) + signed_in = await self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code) except BadRequest as e: print(e.MESSAGE) self.phone_code = None @@ -721,24 +726,24 @@ class Client(Methods, BaseClient): print(e.MESSAGE) while True: - print("Password hint: {}".format(self.get_password_hint())) + print("Password hint: {}".format(await self.get_password_hint())) if not self.password: - self.password = input("Enter password (empty to recover): ") + self.password = await ainput("Enter password (empty to recover): ") try: if not self.password: - confirm = input("Confirm password recovery (y/n): ") + confirm = await ainput("Confirm password recovery (y/n): ") if confirm == "y": - email_pattern = self.send_recovery_code() + email_pattern = await self.send_recovery_code() print("The recovery code has been sent to {}".format(email_pattern)) while True: - recovery_code = input("Enter recovery code: ") + recovery_code = await ainput("Enter recovery code: ") try: - return self.recover_password(recovery_code) + return await self.recover_password(recovery_code) except BadRequest as e: print(e.MESSAGE) except Exception as e: @@ -747,7 +752,7 @@ class Client(Methods, BaseClient): else: self.password = None else: - return self.check_password(self.password) + return await self.check_password(self.password) except BadRequest as e: print(e.MESSAGE) self.password = None @@ -758,11 +763,11 @@ class Client(Methods, BaseClient): return signed_in while True: - first_name = input("Enter first name: ") - last_name = input("Enter last name (empty to skip): ") + first_name = await ainput("Enter first name: ") + last_name = await ainput("Enter last name (empty to skip): ") try: - signed_up = self.sign_up( + signed_up = await self.sign_up( self.phone_number, sent_code.phone_code_hash, first_name, @@ -775,11 +780,11 @@ class Client(Methods, BaseClient): if isinstance(signed_in, TermsOfService): print("\n" + signed_in.text + "\n") - self.accept_terms_of_service(signed_in.id) + await self.accept_terms_of_service(signed_in.id) return signed_up - def log_out(self): + async def log_out(self): """Log out from Telegram and delete the *\\*.session* file. When you log out, the current client is stopped and the storage session deleted. @@ -794,13 +799,13 @@ class Client(Methods, BaseClient): # Log out. app.log_out() """ - self.send(functions.auth.LogOut()) - self.stop() + await self.send(functions.auth.LogOut()) + await self.stop() self.storage.delete() return True - def start(self): + async def start(self): """Start the client. This method connects the client to Telegram and, in case of new sessions, automatically manages the full @@ -825,25 +830,25 @@ class Client(Methods, BaseClient): app.stop() """ - is_authorized = self.connect() + is_authorized = await self.connect() try: if not is_authorized: - self.authorize() + await self.authorize() if not self.storage.is_bot() and self.takeout: - self.takeout_id = self.send(functions.account.InitTakeoutSession()).id + self.takeout_id = (await self.send(functions.account.InitTakeoutSession())).id log.warning("Takeout session {} initiated".format(self.takeout_id)) - self.send(functions.updates.GetState()) + await self.send(functions.updates.GetState()) except (Exception, KeyboardInterrupt): - self.disconnect() + await self.disconnect() raise else: - self.initialize() + await self.initialize() return self - def stop(self, block: bool = True): + async def stop(self, block: bool = True): """Stop the Client. This method disconnects the client from Telegram and stops the underlying tasks. @@ -874,18 +879,18 @@ class Client(Methods, BaseClient): app.stop() """ - def do_it(): - self.terminate() - self.disconnect() + async def do_it(): + await self.terminate() + await self.disconnect() if block: - do_it() + await do_it() else: - Thread(target=do_it).start() + asyncio.ensure_future(do_it()) return self - def restart(self, block: bool = True): + async def restart(self, block: bool = True): """Restart the Client. This method will first call :meth:`~Client.stop` and then :meth:`~Client.start` in a row in order to restart @@ -921,19 +926,19 @@ class Client(Methods, BaseClient): app.stop() """ - def do_it(): - self.stop() - self.start() + async def do_it(): + await self.stop() + await self.start() if block: - do_it() + await do_it() else: - Thread(target=do_it).start() + asyncio.ensure_future(do_it()) return self @staticmethod - def idle(stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)): + async def idle(stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)): """Block the main script execution until a signal is received. This static method will run an infinite loop in order to block the main script execution and prevent it from @@ -978,6 +983,7 @@ class Client(Methods, BaseClient): """ def signal_handler(_, __): + logging.info("Stop signal received ({}). Exiting...".format(_)) Client.is_idling = False for s in stop_signals: @@ -986,9 +992,9 @@ class Client(Methods, BaseClient): Client.is_idling = True while Client.is_idling: - time.sleep(1) + await asyncio.sleep(1) - def run(self): + def run(self, coroutine=None): """Start the client, idle the main script and finally stop the client. This is a convenience method that calls :meth:`~Client.start`, :meth:`~Client.idle` and :meth:`~Client.stop` in @@ -1010,9 +1016,17 @@ class Client(Methods, BaseClient): app.run() """ - self.start() - Client.idle() - self.stop() + loop = asyncio.get_event_loop() + run = loop.run_until_complete + + if coroutine is not None: + run(coroutine) + else: + run(self.start()) + run(Client.idle()) + run(self.stop()) + + loop.close() def add_handler(self, handler: Handler, group: int = 0): """Register an update handler. @@ -1236,12 +1250,9 @@ class Client(Methods, BaseClient): return is_min - def download_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) - + async def download_worker(self): while True: - packet = self.download_queue.get() + packet = await self.download_queue.get() if packet is None: break @@ -1252,7 +1263,7 @@ class Client(Methods, BaseClient): try: data, directory, file_name, done, progress, progress_args, path = packet - temp_file_path = self.get_file( + temp_file_path = await self.get_file( media_type=data.media_type, dc_id=data.dc_id, document_id=data.document_id, @@ -1289,14 +1300,9 @@ class Client(Methods, BaseClient): finally: done.set() - log.debug("{} stopped".format(name)) - - def updates_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) - + async def updates_worker(self): while True: - updates = self.updates_queue.get() + updates = await self.updates_queue.get() if updates is None: break @@ -1328,9 +1334,9 @@ class Client(Methods, BaseClient): if not isinstance(message, types.MessageEmpty): try: - diff = self.send( + diff = await self.send( functions.updates.GetChannelDifference( - channel=self.resolve_peer(utils.get_channel_id(channel_id)), + channel=await self.resolve_peer(utils.get_channel_id(channel_id)), filter=types.ChannelMessagesFilter( ranges=[types.MessageRange( min_id=update.message.id, @@ -1348,9 +1354,9 @@ class Client(Methods, BaseClient): users.update({u.id: u for u in diff.users}) chats.update({c.id: c for c in diff.chats}) - self.dispatcher.updates_queue.put((update, users, chats)) + self.dispatcher.updates_queue.put_nowait((update, users, chats)) elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)): - diff = self.send( + diff = await self.send( functions.updates.GetDifference( pts=updates.pts - updates.pts_count, date=updates.date, @@ -1359,7 +1365,7 @@ class Client(Methods, BaseClient): ) if diff.new_messages: - self.dispatcher.updates_queue.put(( + self.dispatcher.updates_queue.put_nowait(( types.UpdateNewMessage( message=diff.new_messages[0], pts=updates.pts, @@ -1369,17 +1375,15 @@ class Client(Methods, BaseClient): {c.id: c for c in diff.chats} )) else: - self.dispatcher.updates_queue.put((diff.other_updates[0], {}, {})) + self.dispatcher.updates_queue.put_nowait((diff.other_updates[0], {}, {})) elif isinstance(updates, types.UpdateShort): - self.dispatcher.updates_queue.put((updates.update, {}, {})) + self.dispatcher.updates_queue.put_nowait((updates.update, {}, {})) elif isinstance(updates, types.UpdatesTooLong): log.info(updates) except Exception as e: log.error(e, exc_info=True) - log.debug("{} stopped".format(name)) - - def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT): + async def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT): """Send raw Telegram queries. This method makes it possible to manually call every single Telegram API method in a low-level manner. @@ -1417,7 +1421,7 @@ class Client(Methods, BaseClient): if self.takeout_id: data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data) - r = self.session.send(data, retries, timeout, self.sleep_threshold) + r = await self.session.send(data, retries, timeout, self.sleep_threshold) self.fetch_peers(getattr(r, "users", [])) self.fetch_peers(getattr(r, "chats", [])) @@ -1497,7 +1501,7 @@ class Client(Methods, BaseClient): except KeyError: self.plugins = None - def load_session(self): + async def load_session(self): self.storage.open() session_empty = any([ @@ -1512,7 +1516,7 @@ class Client(Methods, BaseClient): self.storage.date(0) self.storage.test_mode(self.test_mode) - self.storage.auth_key(Auth(self, self.storage.dc_id()).create()) + self.storage.auth_key(await Auth(self, self.storage.dc_id()).create()) self.storage.user_id(None) self.storage.is_bot(None) @@ -1632,13 +1636,13 @@ class Client(Methods, BaseClient): self.session_name, name, module_path)) if count > 0: - log.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format( + log.info('[{}] Successfully loaded {} plugin{} from "{}"'.format( self.session_name, count, "s" if count > 1 else "", root)) else: log.warning('[{}] No plugin loaded from "{}"'.format( self.session_name, root)) - def resolve_peer(self, peer_id: Union[int, str]): + async def resolve_peer(self, peer_id: Union[int, str]): """Get the InputPeer of a known peer id. Useful whenever an InputPeer type is required. @@ -1677,7 +1681,7 @@ class Client(Methods, BaseClient): try: return self.storage.get_peer_by_username(peer_id) except KeyError: - self.send( + await self.send( functions.contacts.ResolveUsername( username=peer_id ) @@ -1694,7 +1698,7 @@ class Client(Methods, BaseClient): if peer_type == "user": self.fetch_peers( - self.send( + await self.send( functions.users.GetUsers( id=[ types.InputUser( @@ -1706,13 +1710,13 @@ class Client(Methods, BaseClient): ) ) elif peer_type == "chat": - self.send( + await self.send( functions.messages.GetChats( id=[-peer_id] ) ) else: - self.send( + await self.send( functions.channels.GetChannels( id=[ types.InputChannel( @@ -1728,7 +1732,7 @@ class Client(Methods, BaseClient): except KeyError: raise PeerIdInvalid - def save_file( + async def save_file( self, path: Union[str, BinaryIO], file_id: int = None, @@ -1786,6 +1790,18 @@ class Client(Methods, BaseClient): if path is None: return None + async def worker(session): + while True: + data = await queue.get() + + if data is None: + return + + try: + await asyncio.ensure_future(session.send(data)) + except Exception as e: + logging.error(e) + part_size = 512 * 1024 if isinstance(path, str): @@ -1808,15 +1824,20 @@ class Client(Methods, BaseClient): raise ValueError("Telegram doesn't support uploading files bigger than 2000 MiB") file_total_parts = int(math.ceil(file_size / part_size)) - is_big = True if file_size > 10 * 1024 * 1024 else False - is_missing_part = True if file_id is not None else False + is_big = file_size > 10 * 1024 * 1024 + pool_size = 3 if is_big else 1 + workers_count = 4 if is_big else 1 + is_missing_part = file_id is not None file_id = file_id or self.rnd_id() md5_sum = md5() if not is_big and not is_missing_part else None - - session = Session(self, self.storage.dc_id(), self.storage.auth_key(), is_media=True) - session.start() + pool = [Session(self, self.storage.dc_id(), self.storage.auth_key(), is_media=True) for _ in range(pool_size)] + workers = [asyncio.ensure_future(worker(session)) for session in pool for _ in range(workers_count)] + queue = asyncio.Queue(16) try: + for session in pool: + await session.start() + with fp: fp.seek(part_size * file_part) @@ -1828,25 +1849,21 @@ class Client(Methods, BaseClient): md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()]) break - for _ in range(3): - if is_big: - rpc = functions.upload.SaveBigFilePart( - file_id=file_id, - file_part=file_part, - file_total_parts=file_total_parts, - bytes=chunk - ) - else: - rpc = functions.upload.SaveFilePart( - file_id=file_id, - file_part=file_part, - bytes=chunk - ) - - if session.send(rpc): - break + if is_big: + rpc = functions.upload.SaveBigFilePart( + file_id=file_id, + file_part=file_part, + file_total_parts=file_total_parts, + bytes=chunk + ) else: - raise AssertionError("Telegram didn't accept chunk #{} of {}".format(file_part, path)) + rpc = functions.upload.SaveFilePart( + file_id=file_id, + file_part=file_part, + bytes=chunk + ) + + await queue.put(rpc) if is_missing_part: return @@ -1857,7 +1874,7 @@ class Client(Methods, BaseClient): file_part += 1 if progress: - progress(min(file_part * part_size, file_size), file_size, *progress_args) + await progress(min(file_part * part_size, file_size), file_size, *progress_args) except Client.StopTransmission: raise except Exception as e: @@ -1878,9 +1895,15 @@ class Client(Methods, BaseClient): md5_checksum=md5_sum ) finally: - session.stop() + for _ in workers: + await queue.put(None) - def get_file( + await asyncio.gather(*workers) + + for session in pool: + await session.stop() + + async def get_file( self, media_type: int, dc_id: int, @@ -1898,23 +1921,23 @@ class Client(Methods, BaseClient): progress: callable, progress_args: tuple = () ) -> str: - with self.media_sessions_lock: + async with self.media_sessions_lock: session = self.media_sessions.get(dc_id, None) if session is None: if dc_id != self.storage.dc_id(): - session = Session(self, dc_id, Auth(self, dc_id).create(), is_media=True) - session.start() + session = Session(self, dc_id, await Auth(self, dc_id).create(), is_media=True) + await session.start() for _ in range(3): - exported_auth = self.send( + exported_auth = await self.send( functions.auth.ExportAuthorization( dc_id=dc_id ) ) try: - session.send( + await session.send( functions.auth.ImportAuthorization( id=exported_auth.id, bytes=exported_auth.bytes @@ -1925,11 +1948,11 @@ class Client(Methods, BaseClient): else: break else: - session.stop() + await session.stop() raise AuthBytesInvalid else: session = Session(self, dc_id, self.storage.auth_key(), is_media=True) - session.start() + await session.start() self.media_sessions[dc_id] = session @@ -1984,7 +2007,7 @@ class Client(Methods, BaseClient): file_name = "" try: - r = session.send( + r = await session.send( functions.upload.GetFile( location=location, offset=offset, @@ -2007,7 +2030,7 @@ class Client(Methods, BaseClient): offset += limit if progress: - progress( + await progress( min(offset, file_size) if file_size != 0 else offset, @@ -2015,7 +2038,7 @@ class Client(Methods, BaseClient): *progress_args ) - r = session.send( + r = await session.send( functions.upload.GetFile( location=location, offset=offset, @@ -2024,13 +2047,16 @@ class Client(Methods, BaseClient): ) elif isinstance(r, types.upload.FileCdnRedirect): - with self.media_sessions_lock: + async with self.media_sessions_lock: cdn_session = self.media_sessions.get(r.dc_id, None) if cdn_session is None: - cdn_session = Session(self, r.dc_id, Auth(self, r.dc_id).create(), is_media=True, is_cdn=True) + cdn_session = Session( + self, + r.dc_id, + await Auth(self, r.dc_id).create(), is_media=True, is_cdn=True) - cdn_session.start() + await cdn_session.start() self.media_sessions[r.dc_id] = cdn_session @@ -2039,7 +2065,7 @@ class Client(Methods, BaseClient): file_name = f.name while True: - r2 = cdn_session.send( + r2 = await cdn_session.send( functions.upload.GetCdnFile( file_token=r.file_token, offset=offset, @@ -2049,7 +2075,7 @@ class Client(Methods, BaseClient): if isinstance(r2, types.upload.CdnFileReuploadNeeded): try: - session.send( + await session.send( functions.upload.ReuploadCdnFile( file_token=r.file_token, request_token=r2.request_token @@ -2072,7 +2098,7 @@ class Client(Methods, BaseClient): ) ) - hashes = session.send( + hashes = await session.send( functions.upload.GetCdnFileHashes( file_token=r.file_token, offset=offset @@ -2089,7 +2115,7 @@ class Client(Methods, BaseClient): offset += limit if progress: - progress( + await progress( min(offset, file_size) if file_size != 0 else offset, diff --git a/pyrogram/client/ext/base_client.py b/pyrogram/client/ext/base_client.py index 750dc3fc..ca4e8f5b 100644 --- a/pyrogram/client/ext/base_client.py +++ b/pyrogram/client/ext/base_client.py @@ -16,13 +16,12 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import os import platform import re import sys from pathlib import Path -from queue import Queue -from threading import Lock from pyrogram import __version__ from ..parser import Parser @@ -30,7 +29,7 @@ from ...session.internals import MsgId class BaseClient: - class StopTransmission(StopIteration): + class StopTransmission(StopAsyncIteration): pass APP_VERSION = "Pyrogram {}".format(__version__) @@ -52,7 +51,7 @@ class BaseClient: INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$") DIALOGS_AT_ONCE = 100 UPDATES_WORKERS = 4 - DOWNLOAD_WORKERS = 1 + DOWNLOAD_WORKERS = 4 OFFLINE_SLEEP = 900 WORKERS = 4 WORKDIR = PARENT_DIR @@ -100,24 +99,24 @@ class BaseClient: self.session = None self.media_sessions = {} - self.media_sessions_lock = Lock() + self.media_sessions_lock = asyncio.Lock() self.is_connected = None self.is_initialized = None self.takeout_id = None - self.updates_queue = Queue() - self.updates_workers_list = [] - self.download_queue = Queue() - self.download_workers_list = [] + self.updates_queue = asyncio.Queue() + self.updates_worker_tasks = [] + self.download_queue = asyncio.Queue() + self.download_worker_tasks = [] self.disconnect_handler = None - def send(self, *args, **kwargs): + async def send(self, *args, **kwargs): pass - def resolve_peer(self, *args, **kwargs): + async def resolve_peer(self, *args, **kwargs): pass def fetch_peers(self, *args, **kwargs): @@ -126,25 +125,46 @@ class BaseClient: def add_handler(self, *args, **kwargs): pass - def save_file(self, *args, **kwargs): + async def save_file(self, *args, **kwargs): pass - def get_messages(self, *args, **kwargs): + async def get_messages(self, *args, **kwargs): pass - def get_history(self, *args, **kwargs): + async def get_history(self, *args, **kwargs): pass - def get_dialogs(self, *args, **kwargs): + async def get_dialogs(self, *args, **kwargs): pass - def get_chat_members(self, *args, **kwargs): + async def get_chat_members(self, *args, **kwargs): pass - def get_chat_members_count(self, *args, **kwargs): + async def get_chat_members_count(self, *args, **kwargs): pass - def answer_inline_query(self, *args, **kwargs): + async def answer_inline_query(self, *args, **kwargs): + pass + + async def get_profile_photos(self, *args, **kwargs): + pass + + async def edit_message_text(self, *args, **kwargs): + pass + + async def edit_inline_text(self, *args, **kwargs): + pass + + async def edit_message_media(self, *args, **kwargs): + pass + + async def edit_inline_media(self, *args, **kwargs): + pass + + async def edit_message_reply_markup(self, *args, **kwargs): + pass + + async def edit_inline_reply_markup(self, *args, **kwargs): pass def guess_mime_type(self, *args, **kwargs): @@ -152,24 +172,3 @@ class BaseClient: def guess_extension(self, *args, **kwargs): pass - - def get_profile_photos(self, *args, **kwargs): - pass - - def edit_message_text(self, *args, **kwargs): - pass - - def edit_inline_text(self, *args, **kwargs): - pass - - def edit_message_media(self, *args, **kwargs): - pass - - def edit_inline_media(self, *args, **kwargs): - pass - - def edit_message_reply_markup(self, *args, **kwargs): - pass - - def edit_inline_reply_markup(self, *args, **kwargs): - pass diff --git a/pyrogram/client/ext/dispatcher.py b/pyrogram/client/ext/dispatcher.py index 256dd5f2..818bc60c 100644 --- a/pyrogram/client/ext/dispatcher.py +++ b/pyrogram/client/ext/dispatcher.py @@ -16,11 +16,9 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging -import threading from collections import OrderedDict -from queue import Queue -from threading import Thread, Lock import pyrogram from pyrogram.api.types import ( @@ -69,105 +67,106 @@ class Dispatcher: self.client = client self.workers = workers - self.workers_list = [] + self.update_worker_tasks = [] self.locks_list = [] - self.updates_queue = Queue() + self.updates_queue = asyncio.Queue() self.groups = OrderedDict() + async def message_parser(update, users, chats): + return await pyrogram.Message._parse( + self.client, update.message, users, chats, + isinstance(update, UpdateNewScheduledMessage) + ), MessageHandler + + async def deleted_messages_parser(update, users, chats): + return utils.parse_deleted_messages(self.client, update), DeletedMessagesHandler + + async def callback_query_parser(update, users, chats): + return await pyrogram.CallbackQuery._parse(self.client, update, users), CallbackQueryHandler + + async def user_status_parser(update, users, chats): + return pyrogram.User._parse_user_status(self.client, update), UserStatusHandler + + async def inline_query_parser(update, users, chats): + return pyrogram.InlineQuery._parse(self.client, update, users), InlineQueryHandler + + async def poll_parser(update, users, chats): + return pyrogram.Poll._parse_update(self.client, update), PollHandler + + async def chosen_inline_result_parser(update, users, chats): + return pyrogram.ChosenInlineResult._parse(self.client, update, users), ChosenInlineResultHandler + self.update_parsers = { - Dispatcher.MESSAGE_UPDATES: - lambda upd, usr, cht: ( - pyrogram.Message._parse( - self.client, - upd.message, - usr, - cht, - isinstance(upd, UpdateNewScheduledMessage) - ), - MessageHandler - ), - - Dispatcher.DELETE_MESSAGES_UPDATES: - lambda upd, usr, cht: (utils.parse_deleted_messages(self.client, upd), DeletedMessagesHandler), - - Dispatcher.CALLBACK_QUERY_UPDATES: - lambda upd, usr, cht: (pyrogram.CallbackQuery._parse(self.client, upd, usr), CallbackQueryHandler), - - (UpdateUserStatus,): - lambda upd, usr, cht: (pyrogram.User._parse_user_status(self.client, upd), UserStatusHandler), - - (UpdateBotInlineQuery,): - lambda upd, usr, cht: (pyrogram.InlineQuery._parse(self.client, upd, usr), InlineQueryHandler), - - (UpdateMessagePoll,): - lambda upd, usr, cht: (pyrogram.Poll._parse_update(self.client, upd), PollHandler), - - (UpdateBotInlineSend,): - lambda upd, usr, cht: (pyrogram.ChosenInlineResult._parse(self.client, upd, usr), - ChosenInlineResultHandler) + Dispatcher.MESSAGE_UPDATES: message_parser, + Dispatcher.DELETE_MESSAGES_UPDATES: deleted_messages_parser, + Dispatcher.CALLBACK_QUERY_UPDATES: callback_query_parser, + (UpdateUserStatus,): user_status_parser, + (UpdateBotInlineQuery,): inline_query_parser, + (UpdateMessagePoll,): poll_parser, + (UpdateBotInlineSend,): chosen_inline_result_parser } self.update_parsers = {key: value for key_tuple, value in self.update_parsers.items() for key in key_tuple} - def start(self): + async def start(self): for i in range(self.workers): - self.locks_list.append(Lock()) + self.locks_list.append(asyncio.Lock()) - self.workers_list.append( - Thread( - target=self.update_worker, - name="UpdateWorker#{}".format(i + 1), - args=(self.locks_list[-1],) - ) + self.update_worker_tasks.append( + asyncio.ensure_future(self.update_worker(self.locks_list[-1])) ) - self.workers_list[-1].start() + logging.info("Started {} UpdateWorkerTasks".format(self.workers)) - def stop(self): - for _ in range(self.workers): - self.updates_queue.put(None) + async def stop(self): + for i in range(self.workers): + self.updates_queue.put_nowait(None) - for worker in self.workers_list: - worker.join() + for i in self.update_worker_tasks: + await i - self.workers_list.clear() - self.locks_list.clear() + self.update_worker_tasks.clear() self.groups.clear() + logging.info("Stopped {} UpdateWorkerTasks".format(self.workers)) + def add_handler(self, handler, group: int): - for lock in self.locks_list: - lock.acquire() - - try: - if group not in self.groups: - self.groups[group] = [] - self.groups = OrderedDict(sorted(self.groups.items())) - - self.groups[group].append(handler) - finally: + async def fn(): for lock in self.locks_list: - lock.release() + await lock.acquire() + + try: + if group not in self.groups: + self.groups[group] = [] + self.groups = OrderedDict(sorted(self.groups.items())) + + self.groups[group].append(handler) + finally: + for lock in self.locks_list: + lock.release() + + asyncio.ensure_future(fn()) def remove_handler(self, handler, group: int): - for lock in self.locks_list: - lock.acquire() - - try: - if group not in self.groups: - raise ValueError("Group {} does not exist. Handler was not removed.".format(group)) - - self.groups[group].remove(handler) - finally: + async def fn(): for lock in self.locks_list: - lock.release() + await lock.acquire() - def update_worker(self, lock): - name = threading.current_thread().name - log.debug("{} started".format(name)) + try: + if group not in self.groups: + raise ValueError("Group {} does not exist. Handler was not removed.".format(group)) + self.groups[group].remove(handler) + finally: + for lock in self.locks_list: + lock.release() + + asyncio.ensure_future(fn()) + + async def update_worker(self, lock): while True: - packet = self.updates_queue.get() + packet = await self.updates_queue.get() if packet is None: break @@ -177,12 +176,12 @@ class Dispatcher: parser = self.update_parsers.get(type(update), None) parsed_update, handler_type = ( - parser(update, users, chats) + await parser(update, users, chats) if parser is not None else (None, type(None)) ) - with lock: + async with lock: for group in self.groups.values(): for handler in group: args = None @@ -202,7 +201,7 @@ class Dispatcher: continue try: - handler.callback(self.client, *args) + await handler.callback(self.client, *args) except pyrogram.StopPropagation: raise except pyrogram.ContinuePropagation: @@ -215,5 +214,3 @@ class Dispatcher: pass except Exception as e: log.error(e, exc_info=True) - - log.debug("{} stopped".format(name)) diff --git a/pyrogram/client/ext/syncer.py b/pyrogram/client/ext/syncer.py index bfe99c98..65fb2104 100644 --- a/pyrogram/client/ext/syncer.py +++ b/pyrogram/client/ext/syncer.py @@ -16,9 +16,9 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging import time -from threading import Thread, Event, Lock log = logging.getLogger(__name__) @@ -27,13 +27,18 @@ class Syncer: INTERVAL = 20 clients = {} - thread = None - event = Event() - lock = Lock() + event = None + lock = None @classmethod - def add(cls, client): - with cls.lock: + async def add(cls, client): + if cls.event is None: + cls.event = asyncio.Event() + + if cls.lock is None: + cls.lock = asyncio.Lock() + + async with cls.lock: cls.sync(client) cls.clients[id(client)] = client @@ -42,8 +47,8 @@ class Syncer: cls.start() @classmethod - def remove(cls, client): - with cls.lock: + async def remove(cls, client): + async with cls.lock: cls.sync(client) del cls.clients[id(client)] @@ -54,25 +59,24 @@ class Syncer: @classmethod def start(cls): cls.event.clear() - cls.thread = Thread(target=cls.worker, name=cls.__name__) - cls.thread.start() + asyncio.ensure_future(cls.worker()) @classmethod def stop(cls): cls.event.set() @classmethod - def worker(cls): + async def worker(cls): while True: - cls.event.wait(cls.INTERVAL) - - if cls.event.is_set(): + try: + await asyncio.wait_for(cls.event.wait(), cls.INTERVAL) + except asyncio.TimeoutError: + async with cls.lock: + for client in cls.clients.values(): + cls.sync(client) + else: break - with cls.lock: - for client in cls.clients.values(): - cls.sync(client) - @classmethod def sync(cls, client): try: diff --git a/pyrogram/client/ext/utils.py b/pyrogram/client/ext/utils.py index b46a8038..9359ff05 100644 --- a/pyrogram/client/ext/utils.py +++ b/pyrogram/client/ext/utils.py @@ -16,8 +16,11 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import base64 import struct +import sys +from concurrent.futures.thread import ThreadPoolExecutor from typing import List from typing import Union @@ -80,6 +83,15 @@ def decode_file_ref(file_ref: str) -> bytes: return base64.urlsafe_b64decode(file_ref + "=" * (-len(file_ref) % 4)) +async def ainput(prompt: str = ""): + print(prompt, end="", flush=True) + + with ThreadPoolExecutor(1) as executor: + return (await asyncio.get_event_loop().run_in_executor( + executor, sys.stdin.readline + )).rstrip() + + def get_offset_date(dialogs): for m in reversed(dialogs.messages): if isinstance(m, types.MessageEmpty): @@ -141,24 +153,24 @@ def get_input_media_from_file_id( raise ValueError("Unknown media type: {}".format(file_id_str)) -def parse_messages(client, messages: types.messages.Messages, replies: int = 1) -> List["pyrogram.Message"]: +async def parse_messages(client, messages: types.messages.Messages, replies: int = 1) -> List["pyrogram.Message"]: users = {i.id: i for i in messages.users} chats = {i.id: i for i in messages.chats} if not messages.messages: return pyrogram.List() - parsed_messages = [ - pyrogram.Message._parse(client, message, users, chats, replies=0) - for message in messages.messages - ] + parsed_messages = [] + + for message in messages.messages: + parsed_messages.append(await pyrogram.Message._parse(client, message, users, chats, replies=0)) if replies: messages_with_replies = {i.id: getattr(i, "reply_to_msg_id", None) for i in messages.messages} reply_message_ids = [i[0] for i in filter(lambda x: x[1] is not None, messages_with_replies.items())] if reply_message_ids: - reply_messages = client.get_messages( + reply_messages = await client.get_messages( parsed_messages[0].chat.id, reply_to_message_ids=reply_message_ids, replies=replies - 1 diff --git a/pyrogram/client/methods/bots/answer_callback_query.py b/pyrogram/client/methods/bots/answer_callback_query.py index 2e00c07c..ff9a5b51 100644 --- a/pyrogram/client/methods/bots/answer_callback_query.py +++ b/pyrogram/client/methods/bots/answer_callback_query.py @@ -21,7 +21,7 @@ from pyrogram.client.ext import BaseClient class AnswerCallbackQuery(BaseClient): - def answer_callback_query( + async def answer_callback_query( self, callback_query_id: str, text: str = None, @@ -68,7 +68,7 @@ class AnswerCallbackQuery(BaseClient): # Answer with alert app.answer_callback_query(query_id, text=text, show_alert=True) """ - return self.send( + return await self.send( functions.messages.SetBotCallbackAnswer( query_id=int(callback_query_id), cache_time=cache_time, diff --git a/pyrogram/client/methods/bots/answer_inline_query.py b/pyrogram/client/methods/bots/answer_inline_query.py index 2f95c9b9..69b9184d 100644 --- a/pyrogram/client/methods/bots/answer_inline_query.py +++ b/pyrogram/client/methods/bots/answer_inline_query.py @@ -24,7 +24,7 @@ from ...types.inline_mode import InlineQueryResult class AnswerInlineQuery(BaseClient): - def answer_inline_query( + async def answer_inline_query( self, inline_query_id: str, results: List[InlineQueryResult], @@ -93,10 +93,15 @@ class AnswerInlineQuery(BaseClient): "Title", InputTextMessageContent("Message content"))]) """ - return self.send( + written_results = [] # Py 3.5 doesn't support await inside comprehensions + + for r in results: + written_results.append(await r.write()) + + return await self.send( functions.messages.SetInlineBotResults( query_id=int(inline_query_id), - results=[r.write() for r in results], + results=written_results, cache_time=cache_time, gallery=is_gallery or None, private=is_personal or None, diff --git a/pyrogram/client/methods/bots/get_game_high_scores.py b/pyrogram/client/methods/bots/get_game_high_scores.py index 1cebc8a6..c40350ad 100644 --- a/pyrogram/client/methods/bots/get_game_high_scores.py +++ b/pyrogram/client/methods/bots/get_game_high_scores.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class GetGameHighScores(BaseClient): - def get_game_high_scores( + async def get_game_high_scores( self, user_id: Union[int, str], chat_id: Union[int, str], @@ -59,11 +59,11 @@ class GetGameHighScores(BaseClient): """ # TODO: inline_message_id - r = self.send( + r = await self.send( functions.messages.GetGameHighScores( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), id=message_id, - user_id=self.resolve_peer(user_id) + user_id=await self.resolve_peer(user_id) ) ) diff --git a/pyrogram/client/methods/bots/get_inline_bot_results.py b/pyrogram/client/methods/bots/get_inline_bot_results.py index aa27b7c9..366594fc 100644 --- a/pyrogram/client/methods/bots/get_inline_bot_results.py +++ b/pyrogram/client/methods/bots/get_inline_bot_results.py @@ -24,7 +24,7 @@ from pyrogram.errors import UnknownError class GetInlineBotResults(BaseClient): - def get_inline_bot_results( + async def get_inline_bot_results( self, bot: Union[int, str], query: str = "", @@ -70,9 +70,9 @@ class GetInlineBotResults(BaseClient): # TODO: Don't return the raw type try: - return self.send( + return await self.send( functions.messages.GetInlineBotResults( - bot=self.resolve_peer(bot), + bot=await self.resolve_peer(bot), peer=types.InputPeerSelf(), query=query, offset=offset, diff --git a/pyrogram/client/methods/bots/request_callback_answer.py b/pyrogram/client/methods/bots/request_callback_answer.py index 6178b940..97eacf0d 100644 --- a/pyrogram/client/methods/bots/request_callback_answer.py +++ b/pyrogram/client/methods/bots/request_callback_answer.py @@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient class RequestCallbackAnswer(BaseClient): - def request_callback_answer( + async def request_callback_answer( self, chat_id: Union[int, str], message_id: int, @@ -64,9 +64,9 @@ class RequestCallbackAnswer(BaseClient): # Telegram only wants bytes, but we are allowed to pass strings too. data = bytes(callback_data, "utf-8") if isinstance(callback_data, str) else callback_data - return self.send( + return await self.send( functions.messages.GetBotCallbackAnswer( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), msg_id=message_id, data=data ), diff --git a/pyrogram/client/methods/bots/send_game.py b/pyrogram/client/methods/bots/send_game.py index e9513ac8..4b4d2c02 100644 --- a/pyrogram/client/methods/bots/send_game.py +++ b/pyrogram/client/methods/bots/send_game.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class SendGame(BaseClient): - def send_game( + async def send_game( self, chat_id: Union[int, str], game_short_name: str, @@ -67,9 +67,9 @@ class SendGame(BaseClient): app.send_game(chat_id, "gamename") """ - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaGame( id=types.InputGameShortName( bot_id=types.InputUserSelf(), @@ -86,7 +86,7 @@ class SendGame(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats} diff --git a/pyrogram/client/methods/bots/send_inline_bot_result.py b/pyrogram/client/methods/bots/send_inline_bot_result.py index 9b2cdf60..8cc5bf11 100644 --- a/pyrogram/client/methods/bots/send_inline_bot_result.py +++ b/pyrogram/client/methods/bots/send_inline_bot_result.py @@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient class SendInlineBotResult(BaseClient): - def send_inline_bot_result( + async def send_inline_bot_result( self, chat_id: Union[int, str], query_id: int, @@ -65,9 +65,9 @@ class SendInlineBotResult(BaseClient): app.send_inline_bot_result(chat_id, query_id, result_id) """ - return self.send( + return await self.send( functions.messages.SendInlineBotResult( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), query_id=query_id, id=result_id, random_id=self.rnd_id(), diff --git a/pyrogram/client/methods/bots/set_game_score.py b/pyrogram/client/methods/bots/set_game_score.py index 25d8fc0b..3912d294 100644 --- a/pyrogram/client/methods/bots/set_game_score.py +++ b/pyrogram/client/methods/bots/set_game_score.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class SetGameScore(BaseClient): - def set_game_score( + async def set_game_score( self, user_id: Union[int, str], score: int, @@ -75,12 +75,12 @@ class SetGameScore(BaseClient): # Force set new score app.set_game_score(user_id, 25, force=True) """ - r = self.send( + r = await self.send( functions.messages.SetGameScore( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), score=score, id=message_id, - user_id=self.resolve_peer(user_id), + user_id=await self.resolve_peer(user_id), force=force or None, edit_message=not disable_edit_message or None ) @@ -88,7 +88,7 @@ class SetGameScore(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats} diff --git a/pyrogram/client/methods/chats/add_chat_members.py b/pyrogram/client/methods/chats/add_chat_members.py index b04d5555..9a5f18ea 100644 --- a/pyrogram/client/methods/chats/add_chat_members.py +++ b/pyrogram/client/methods/chats/add_chat_members.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class AddChatMembers(BaseClient): - def add_chat_members( + async def add_chat_members( self, chat_id: Union[int, str], user_ids: Union[Union[int, str], List[Union[int, str]]], @@ -60,26 +60,26 @@ class AddChatMembers(BaseClient): # Change forward_limit (for basic groups only) app.add_chat_members(chat_id, user_id, forward_limit=25) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if not isinstance(user_ids, list): user_ids = [user_ids] if isinstance(peer, types.InputPeerChat): for user_id in user_ids: - self.send( + await self.send( functions.messages.AddChatUser( chat_id=peer.chat_id, - user_id=self.resolve_peer(user_id), + user_id=await self.resolve_peer(user_id), fwd_limit=forward_limit ) ) else: - self.send( + await self.send( functions.channels.InviteToChannel( channel=peer, users=[ - self.resolve_peer(user_id) + await self.resolve_peer(user_id) for user_id in user_ids ] ) diff --git a/pyrogram/client/methods/chats/archive_chats.py b/pyrogram/client/methods/chats/archive_chats.py index 3c1cabf7..54c452a2 100644 --- a/pyrogram/client/methods/chats/archive_chats.py +++ b/pyrogram/client/methods/chats/archive_chats.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class ArchiveChats(BaseClient): - def archive_chats( + async def archive_chats( self, chat_ids: Union[int, str, List[Union[int, str]]], ) -> bool: @@ -50,14 +50,19 @@ class ArchiveChats(BaseClient): if not isinstance(chat_ids, list): chat_ids = [chat_ids] - self.send( + folder_peers = [] + + for chat in chat_ids: + folder_peers.append( + types.InputFolderPeer( + peer=await self.resolve_peer(chat), + folder_id=1 + ) + ) + + await self.send( functions.folders.EditPeerFolders( - folder_peers=[ - types.InputFolderPeer( - peer=self.resolve_peer(chat), - folder_id=1 - ) for chat in chat_ids - ] + folder_peers=folder_peers ) ) diff --git a/pyrogram/client/methods/chats/create_channel.py b/pyrogram/client/methods/chats/create_channel.py index 5986f703..7885ed3e 100644 --- a/pyrogram/client/methods/chats/create_channel.py +++ b/pyrogram/client/methods/chats/create_channel.py @@ -22,7 +22,7 @@ from ...ext import BaseClient class CreateChannel(BaseClient): - def create_channel( + async def create_channel( self, title: str, description: str = "" @@ -44,7 +44,7 @@ class CreateChannel(BaseClient): app.create_channel("Channel Title", "Channel Description") """ - r = self.send( + r = await self.send( functions.channels.CreateChannel( title=title, about=description, diff --git a/pyrogram/client/methods/chats/create_group.py b/pyrogram/client/methods/chats/create_group.py index 43ec6e7f..631aa75a 100644 --- a/pyrogram/client/methods/chats/create_group.py +++ b/pyrogram/client/methods/chats/create_group.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class CreateGroup(BaseClient): - def create_group( + async def create_group( self, title: str, users: Union[Union[int, str], List[Union[int, str]]] @@ -55,10 +55,10 @@ class CreateGroup(BaseClient): if not isinstance(users, list): users = [users] - r = self.send( + r = await self.send( functions.messages.CreateChat( title=title, - users=[self.resolve_peer(u) for u in users] + users=[await self.resolve_peer(u) for u in users] ) ) diff --git a/pyrogram/client/methods/chats/create_supergroup.py b/pyrogram/client/methods/chats/create_supergroup.py index 139064ec..1310d65e 100644 --- a/pyrogram/client/methods/chats/create_supergroup.py +++ b/pyrogram/client/methods/chats/create_supergroup.py @@ -22,7 +22,7 @@ from ...ext import BaseClient class CreateSupergroup(BaseClient): - def create_supergroup( + async def create_supergroup( self, title: str, description: str = "" @@ -48,7 +48,7 @@ class CreateSupergroup(BaseClient): app.create_supergroup("Supergroup Title", "Supergroup Description") """ - r = self.send( + r = await self.send( functions.channels.CreateChannel( title=title, about=description, diff --git a/pyrogram/client/methods/chats/delete_channel.py b/pyrogram/client/methods/chats/delete_channel.py index fd07b0e6..ff62e8d6 100644 --- a/pyrogram/client/methods/chats/delete_channel.py +++ b/pyrogram/client/methods/chats/delete_channel.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class DeleteChannel(BaseClient): - def delete_channel(self, chat_id: Union[int, str]) -> bool: + async def delete_channel(self, chat_id: Union[int, str]) -> bool: """Delete a channel. Parameters: @@ -38,9 +38,9 @@ class DeleteChannel(BaseClient): app.delete_channel(channel_id) """ - self.send( + await self.send( functions.channels.DeleteChannel( - channel=self.resolve_peer(chat_id) + channel=await self.resolve_peer(chat_id) ) ) diff --git a/pyrogram/client/methods/chats/delete_chat_photo.py b/pyrogram/client/methods/chats/delete_chat_photo.py index 655d6fd6..cc2e06dd 100644 --- a/pyrogram/client/methods/chats/delete_chat_photo.py +++ b/pyrogram/client/methods/chats/delete_chat_photo.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class DeleteChatPhoto(BaseClient): - def delete_chat_photo( + async def delete_chat_photo( self, chat_id: Union[int, str] ) -> bool: @@ -46,17 +46,17 @@ class DeleteChatPhoto(BaseClient): app.delete_chat_photo(chat_id) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChat): - self.send( + await self.send( functions.messages.EditChatPhoto( chat_id=peer.chat_id, photo=types.InputChatPhotoEmpty() ) ) elif isinstance(peer, types.InputPeerChannel): - self.send( + await self.send( functions.channels.EditPhoto( channel=peer, photo=types.InputChatPhotoEmpty() diff --git a/pyrogram/client/methods/chats/delete_supergroup.py b/pyrogram/client/methods/chats/delete_supergroup.py index df4649e5..f24c55a1 100644 --- a/pyrogram/client/methods/chats/delete_supergroup.py +++ b/pyrogram/client/methods/chats/delete_supergroup.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class DeleteSupergroup(BaseClient): - def delete_supergroup(self, chat_id: Union[int, str]) -> bool: + async def delete_supergroup(self, chat_id: Union[int, str]) -> bool: """Delete a supergroup. Parameters: @@ -38,9 +38,9 @@ class DeleteSupergroup(BaseClient): app.delete_supergroup(supergroup_id) """ - self.send( + await self.send( functions.channels.DeleteChannel( - channel=self.resolve_peer(chat_id) + channel=await self.resolve_peer(chat_id) ) ) diff --git a/pyrogram/client/methods/chats/export_chat_invite_link.py b/pyrogram/client/methods/chats/export_chat_invite_link.py index 671c1ade..ac0d3b91 100644 --- a/pyrogram/client/methods/chats/export_chat_invite_link.py +++ b/pyrogram/client/methods/chats/export_chat_invite_link.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class ExportChatInviteLink(BaseClient): - def export_chat_invite_link( + async def export_chat_invite_link( self, chat_id: Union[int, str] ) -> str: @@ -55,13 +55,13 @@ class ExportChatInviteLink(BaseClient): link = app.export_chat_invite_link(chat_id) print(link) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, (types.InputPeerChat, types.InputPeerChannel)): - return self.send( + return (await self.send( functions.messages.ExportChatInvite( peer=peer ) - ).link + )).link else: raise ValueError('The chat_id "{}" belongs to a user'.format(chat_id)) diff --git a/pyrogram/client/methods/chats/get_chat.py b/pyrogram/client/methods/chats/get_chat.py index 14adc1a7..3c945070 100644 --- a/pyrogram/client/methods/chats/get_chat.py +++ b/pyrogram/client/methods/chats/get_chat.py @@ -24,7 +24,7 @@ from ...ext import BaseClient, utils class GetChat(BaseClient): - def get_chat( + async def get_chat( self, chat_id: Union[int, str] ) -> Union["pyrogram.Chat", "pyrogram.ChatPreview"]: @@ -55,7 +55,7 @@ class GetChat(BaseClient): match = self.INVITE_LINK_RE.match(str(chat_id)) if match: - r = self.send( + r = await self.send( functions.messages.CheckChatInvite( hash=match.group(1) ) @@ -72,13 +72,13 @@ class GetChat(BaseClient): if isinstance(r.chat, types.Channel): chat_id = utils.get_channel_id(r.chat.id) - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChannel): - r = self.send(functions.channels.GetFullChannel(channel=peer)) + r = await self.send(functions.channels.GetFullChannel(channel=peer)) elif isinstance(peer, (types.InputPeerUser, types.InputPeerSelf)): - r = self.send(functions.users.GetFullUser(id=peer)) + r = await self.send(functions.users.GetFullUser(id=peer)) else: - r = self.send(functions.messages.GetFullChat(chat_id=peer.chat_id)) + r = await self.send(functions.messages.GetFullChat(chat_id=peer.chat_id)) - return pyrogram.Chat._parse_full(self, r) + return await pyrogram.Chat._parse_full(self, r) diff --git a/pyrogram/client/methods/chats/get_chat_member.py b/pyrogram/client/methods/chats/get_chat_member.py index 9a7bdeff..b77bca85 100644 --- a/pyrogram/client/methods/chats/get_chat_member.py +++ b/pyrogram/client/methods/chats/get_chat_member.py @@ -21,11 +21,12 @@ from typing import Union import pyrogram from pyrogram.api import functions, types from pyrogram.errors import UserNotParticipant + from ...ext import BaseClient class GetChatMember(BaseClient): - def get_chat_member( + async def get_chat_member( self, chat_id: Union[int, str], user_id: Union[int, str] @@ -50,11 +51,11 @@ class GetChatMember(BaseClient): dan = app.get_chat_member("pyrogramchat", "haskell") print(dan) """ - chat = self.resolve_peer(chat_id) - user = self.resolve_peer(user_id) + chat = await self.resolve_peer(chat_id) + user = await self.resolve_peer(user_id) if isinstance(chat, types.InputPeerChat): - r = self.send( + r = await self.send( functions.messages.GetFullChat( chat_id=chat.chat_id ) @@ -75,7 +76,7 @@ class GetChatMember(BaseClient): else: raise UserNotParticipant elif isinstance(chat, types.InputPeerChannel): - r = self.send( + r = await self.send( functions.channels.GetParticipant( channel=chat, user_id=user diff --git a/pyrogram/client/methods/chats/get_chat_members.py b/pyrogram/client/methods/chats/get_chat_members.py index da1954a4..4f7613ce 100644 --- a/pyrogram/client/methods/chats/get_chat_members.py +++ b/pyrogram/client/methods/chats/get_chat_members.py @@ -36,7 +36,7 @@ class Filters: class GetChatMembers(BaseClient): - def get_chat_members( + async def get_chat_members( self, chat_id: Union[int, str], offset: int = 0, @@ -103,10 +103,10 @@ class GetChatMembers(BaseClient): # Get all bots app.get_chat_members("pyrogramchat", filter="bots") """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChat): - r = self.send( + r = await self.send( functions.messages.GetFullChat( chat_id=peer.chat_id ) @@ -134,7 +134,7 @@ class GetChatMembers(BaseClient): else: raise ValueError("Invalid filter \"{}\"".format(filter)) - r = self.send( + r = await self.send( functions.channels.GetParticipants( channel=peer, filter=filter, diff --git a/pyrogram/client/methods/chats/get_chat_members_count.py b/pyrogram/client/methods/chats/get_chat_members_count.py index ad77acc1..88c12669 100644 --- a/pyrogram/client/methods/chats/get_chat_members_count.py +++ b/pyrogram/client/methods/chats/get_chat_members_count.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class GetChatMembersCount(BaseClient): - def get_chat_members_count( + async def get_chat_members_count( self, chat_id: Union[int, str] ) -> int: @@ -45,19 +45,23 @@ class GetChatMembersCount(BaseClient): count = app.get_chat_members_count("pyrogramchat") print(count) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChat): - return self.send( + r = await self.send( functions.messages.GetChats( id=[peer.chat_id] ) - ).chats[0].participants_count + ) + + return r.chats[0].participants_count elif isinstance(peer, types.InputPeerChannel): - return self.send( + r = await self.send( functions.channels.GetFullChannel( channel=peer ) - ).full_chat.participants_count + ) + + return r.full_chat.participants_count else: raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id)) diff --git a/pyrogram/client/methods/chats/get_dialogs.py b/pyrogram/client/methods/chats/get_dialogs.py index de03b046..d52bd2ce 100644 --- a/pyrogram/client/methods/chats/get_dialogs.py +++ b/pyrogram/client/methods/chats/get_dialogs.py @@ -27,7 +27,7 @@ log = logging.getLogger(__name__) class GetDialogs(BaseClient): - def get_dialogs( + async def get_dialogs( self, offset_date: int = 0, limit: int = 100, @@ -65,9 +65,9 @@ class GetDialogs(BaseClient): """ if pinned_only: - r = self.send(functions.messages.GetPinnedDialogs(folder_id=0)) + r = await self.send(functions.messages.GetPinnedDialogs(folder_id=0)) else: - r = self.send( + r = await self.send( functions.messages.GetDialogs( offset_date=offset_date, offset_id=0, @@ -94,7 +94,7 @@ class GetDialogs(BaseClient): else: chat_id = utils.get_peer_id(to_id) - messages[chat_id] = pyrogram.Message._parse(self, message, users, chats) + messages[chat_id] = await pyrogram.Message._parse(self, message, users, chats) parsed_dialogs = [] diff --git a/pyrogram/client/methods/chats/get_dialogs_count.py b/pyrogram/client/methods/chats/get_dialogs_count.py index 7b81182e..da4c60ce 100644 --- a/pyrogram/client/methods/chats/get_dialogs_count.py +++ b/pyrogram/client/methods/chats/get_dialogs_count.py @@ -21,7 +21,7 @@ from ...ext import BaseClient class GetDialogsCount(BaseClient): - def get_dialogs_count(self, pinned_only: bool = False) -> int: + async def get_dialogs_count(self, pinned_only: bool = False) -> int: """Get the total count of your dialogs. pinned_only (``bool``, *optional*): @@ -39,9 +39,9 @@ class GetDialogsCount(BaseClient): """ if pinned_only: - return len(self.send(functions.messages.GetPinnedDialogs(folder_id=0)).dialogs) + return len((await self.send(functions.messages.GetPinnedDialogs(folder_id=0))).dialogs) else: - r = self.send( + r = await self.send( functions.messages.GetDialogs( offset_date=0, offset_id=0, diff --git a/pyrogram/client/methods/chats/get_nearby_chats.py b/pyrogram/client/methods/chats/get_nearby_chats.py index 1ccab729..6b4ab56d 100644 --- a/pyrogram/client/methods/chats/get_nearby_chats.py +++ b/pyrogram/client/methods/chats/get_nearby_chats.py @@ -24,7 +24,7 @@ from ...ext import BaseClient, utils class GetNearbyChats(BaseClient): - def get_nearby_chats( + async def get_nearby_chats( self, latitude: float, longitude: float @@ -48,7 +48,7 @@ class GetNearbyChats(BaseClient): print(chats) """ - r = self.send( + r = await self.send( functions.contacts.GetLocated( geo_point=types.InputGeoPoint( lat=latitude, diff --git a/pyrogram/client/methods/chats/iter_chat_members.py b/pyrogram/client/methods/chats/iter_chat_members.py index 0bc90305..b5ded4dc 100644 --- a/pyrogram/client/methods/chats/iter_chat_members.py +++ b/pyrogram/client/methods/chats/iter_chat_members.py @@ -17,10 +17,12 @@ # along with Pyrogram. If not, see . from string import ascii_lowercase -from typing import Union, Generator +from typing import Union, Generator, Optional import pyrogram +from async_generator import async_generator, yield_ from pyrogram.api import types + from ...ext import BaseClient @@ -38,13 +40,14 @@ QUERYABLE_FILTERS = (Filters.ALL, Filters.KICKED, Filters.RESTRICTED) class IterChatMembers(BaseClient): - def iter_chat_members( + @async_generator + async def iter_chat_members( self, chat_id: Union[int, str], limit: int = 0, query: str = "", filter: str = Filters.ALL - ) -> Generator["pyrogram.ChatMember", None, None]: + ) -> Optional[Generator["pyrogram.ChatMember", None, None]]: """Iterate through the members of a chat sequentially. This convenience method does the same as repeatedly calling :meth:`~Client.get_chat_members` in a loop, thus saving you @@ -97,7 +100,7 @@ class IterChatMembers(BaseClient): queries = [query] if query else QUERIES total = limit or (1 << 31) - 1 limit = min(200, total) - resolved_chat_id = self.resolve_peer(chat_id) + resolved_chat_id = await self.resolve_peer(chat_id) if filter not in QUERYABLE_FILTERS: queries = [""] @@ -106,7 +109,7 @@ class IterChatMembers(BaseClient): offset = 0 while True: - chat_members = self.get_chat_members( + chat_members = await self.get_chat_members( chat_id=chat_id, offset=offset, limit=limit, @@ -128,7 +131,7 @@ class IterChatMembers(BaseClient): if user_id in yielded: continue - yield chat_member + await yield_(chat_member) yielded.add(chat_member.user.id) diff --git a/pyrogram/client/methods/chats/iter_dialogs.py b/pyrogram/client/methods/chats/iter_dialogs.py index a2eddcb9..a1933a7e 100644 --- a/pyrogram/client/methods/chats/iter_dialogs.py +++ b/pyrogram/client/methods/chats/iter_dialogs.py @@ -16,18 +16,21 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Generator +from typing import Generator, Optional + +from async_generator import async_generator, yield_ import pyrogram from ...ext import BaseClient class IterDialogs(BaseClient): - def iter_dialogs( + @async_generator + async def iter_dialogs( self, - offset_date: int = 0, - limit: int = None - ) -> Generator["pyrogram.Dialog", None, None]: + limit: int = 0, + offset_date: int = 0 + ) -> Optional[Generator["pyrogram.Dialog", None, None]]: """Iterate through a user's dialogs sequentially. This convenience method does the same as repeatedly calling :meth:`~Client.get_dialogs` in a loop, thus saving @@ -35,14 +38,14 @@ class IterDialogs(BaseClient): single call. Parameters: - offset_date (``int``): - The offset date in Unix time taken from the top message of a :obj:`Dialog`. - Defaults to 0 (most recent dialog). - limit (``int``, *optional*): Limits the number of dialogs to be retrieved. By default, no limit is applied and all dialogs are returned. + offset_date (``int``): + The offset date in Unix time taken from the top message of a :obj:`Dialog`. + Defaults to 0 (most recent dialog). + Returns: ``Generator``: A generator yielding :obj:`Dialog` objects. @@ -57,12 +60,12 @@ class IterDialogs(BaseClient): total = limit or (1 << 31) - 1 limit = min(100, total) - pinned_dialogs = self.get_dialogs( + pinned_dialogs = await self.get_dialogs( pinned_only=True ) for dialog in pinned_dialogs: - yield dialog + await yield_(dialog) current += 1 @@ -70,7 +73,7 @@ class IterDialogs(BaseClient): return while True: - dialogs = self.get_dialogs( + dialogs = await self.get_dialogs( offset_date=offset_date, limit=limit ) @@ -81,7 +84,7 @@ class IterDialogs(BaseClient): offset_date = dialogs[-1].top_message.date for dialog in dialogs: - yield dialog + await yield_(dialog) current += 1 diff --git a/pyrogram/client/methods/chats/join_chat.py b/pyrogram/client/methods/chats/join_chat.py index 66687115..05bde186 100644 --- a/pyrogram/client/methods/chats/join_chat.py +++ b/pyrogram/client/methods/chats/join_chat.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class JoinChat(BaseClient): - def join_chat( + async def join_chat( self, chat_id: Union[int, str] ): @@ -53,7 +53,7 @@ class JoinChat(BaseClient): match = self.INVITE_LINK_RE.match(str(chat_id)) if match: - chat = self.send( + chat = await self.send( functions.messages.ImportChatInvite( hash=match.group(1) ) @@ -63,9 +63,9 @@ class JoinChat(BaseClient): elif isinstance(chat.chats[0], types.Channel): return pyrogram.Chat._parse_channel_chat(self, chat.chats[0]) else: - chat = self.send( + chat = await self.send( functions.channels.JoinChannel( - channel=self.resolve_peer(chat_id) + channel=await self.resolve_peer(chat_id) ) ) diff --git a/pyrogram/client/methods/chats/kick_chat_member.py b/pyrogram/client/methods/chats/kick_chat_member.py index 55a177f4..d72da5ab 100644 --- a/pyrogram/client/methods/chats/kick_chat_member.py +++ b/pyrogram/client/methods/chats/kick_chat_member.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class KickChatMember(BaseClient): - def kick_chat_member( + async def kick_chat_member( self, chat_id: Union[int, str], user_id: Union[int, str], @@ -68,11 +68,11 @@ class KickChatMember(BaseClient): # Kick chat member and automatically unban after 24h app.kick_chat_member(chat_id, user_id, int(time.time() + 86400)) """ - chat_peer = self.resolve_peer(chat_id) - user_peer = self.resolve_peer(user_id) + chat_peer = await self.resolve_peer(chat_id) + user_peer = await self.resolve_peer(user_id) if isinstance(chat_peer, types.InputPeerChannel): - r = self.send( + r = await self.send( functions.channels.EditBanned( channel=chat_peer, user_id=user_peer, @@ -90,7 +90,7 @@ class KickChatMember(BaseClient): ) ) else: - r = self.send( + r = await self.send( functions.messages.DeleteChatUser( chat_id=abs(chat_id), user_id=user_peer @@ -99,7 +99,7 @@ class KickChatMember(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats} diff --git a/pyrogram/client/methods/chats/leave_chat.py b/pyrogram/client/methods/chats/leave_chat.py index 2cc1c057..31b7cc78 100644 --- a/pyrogram/client/methods/chats/leave_chat.py +++ b/pyrogram/client/methods/chats/leave_chat.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class LeaveChat(BaseClient): - def leave_chat( + async def leave_chat( self, chat_id: Union[int, str], delete: bool = False @@ -48,16 +48,16 @@ class LeaveChat(BaseClient): # Leave basic chat and also delete the dialog app.leave_chat(chat_id, delete=True) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChannel): - return self.send( + return await self.send( functions.channels.LeaveChannel( - channel=self.resolve_peer(chat_id) + channel=await self.resolve_peer(chat_id) ) ) elif isinstance(peer, types.InputPeerChat): - r = self.send( + r = await self.send( functions.messages.DeleteChatUser( chat_id=peer.chat_id, user_id=types.InputPeerSelf() @@ -65,7 +65,7 @@ class LeaveChat(BaseClient): ) if delete: - self.send( + await self.send( functions.messages.DeleteHistory( peer=peer, max_id=0 diff --git a/pyrogram/client/methods/chats/pin_chat_message.py b/pyrogram/client/methods/chats/pin_chat_message.py index 44191a2d..6adaa2e8 100644 --- a/pyrogram/client/methods/chats/pin_chat_message.py +++ b/pyrogram/client/methods/chats/pin_chat_message.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class PinChatMessage(BaseClient): - def pin_chat_message( + async def pin_chat_message( self, chat_id: Union[int, str], message_id: int, @@ -56,9 +56,9 @@ class PinChatMessage(BaseClient): # Pin without notification app.pin_chat_message(chat_id, message_id, disable_notification=True) """ - self.send( + await self.send( functions.messages.UpdatePinnedMessage( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), id=message_id, silent=disable_notification or None ) diff --git a/pyrogram/client/methods/chats/promote_chat_member.py b/pyrogram/client/methods/chats/promote_chat_member.py index 70b4f4e2..c6912031 100644 --- a/pyrogram/client/methods/chats/promote_chat_member.py +++ b/pyrogram/client/methods/chats/promote_chat_member.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class PromoteChatMember(BaseClient): - def promote_chat_member( + async def promote_chat_member( self, chat_id: Union[int, str], user_id: Union[int, str], @@ -84,10 +84,10 @@ class PromoteChatMember(BaseClient): # Promote chat member to supergroup admin app.promote_chat_member(chat_id, user_id) """ - self.send( + await self.send( functions.channels.EditAdmin( - channel=self.resolve_peer(chat_id), - user_id=self.resolve_peer(user_id), + channel=await self.resolve_peer(chat_id), + user_id=await self.resolve_peer(user_id), admin_rights=types.ChatAdminRights( change_info=can_change_info or None, post_messages=can_post_messages or None, diff --git a/pyrogram/client/methods/chats/restrict_chat_member.py b/pyrogram/client/methods/chats/restrict_chat_member.py index a0707078..500482a1 100644 --- a/pyrogram/client/methods/chats/restrict_chat_member.py +++ b/pyrogram/client/methods/chats/restrict_chat_member.py @@ -24,7 +24,7 @@ from ...types.user_and_chats import Chat, ChatPermissions class RestrictChatMember(BaseClient): - def restrict_chat_member( + async def restrict_chat_member( self, chat_id: Union[int, str], user_id: Union[int, str], @@ -71,10 +71,10 @@ class RestrictChatMember(BaseClient): # Chat member can only send text messages app.restrict_chat_member(chat_id, user_id, ChatPermissions(can_send_messages=True)) """ - r = self.send( + r = await self.send( functions.channels.EditBanned( - channel=self.resolve_peer(chat_id), - user_id=self.resolve_peer(user_id), + channel=await self.resolve_peer(chat_id), + user_id=await self.resolve_peer(user_id), banned_rights=types.ChatBannedRights( until_date=until_date, send_messages=True if not permissions.can_send_messages else None, diff --git a/pyrogram/client/methods/chats/set_administrator_title.py b/pyrogram/client/methods/chats/set_administrator_title.py index 361a4e1c..0cbc937a 100644 --- a/pyrogram/client/methods/chats/set_administrator_title.py +++ b/pyrogram/client/methods/chats/set_administrator_title.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class SetAdministratorTitle(BaseClient): - def set_administrator_title( + async def set_administrator_title( self, chat_id: Union[int, str], user_id: Union[int, str], @@ -54,15 +54,15 @@ class SetAdministratorTitle(BaseClient): app.set_administrator_title(chat_id, user_id, "ฅ^•ﻌ•^ฅ") """ - chat_id = self.resolve_peer(chat_id) - user_id = self.resolve_peer(user_id) + chat_id = await self.resolve_peer(chat_id) + user_id = await self.resolve_peer(user_id) - r = self.send( + r = (await self.send( functions.channels.GetParticipant( channel=chat_id, user_id=user_id ) - ).participant + )).participant if isinstance(r, types.ChannelParticipantCreator): admin_rights = types.ChatAdminRights( @@ -104,7 +104,7 @@ class SetAdministratorTitle(BaseClient): if not admin_rights.add_admins: admin_rights.add_admins = None - self.send( + await self.send( functions.channels.EditAdmin( channel=chat_id, user_id=user_id, diff --git a/pyrogram/client/methods/chats/set_chat_description.py b/pyrogram/client/methods/chats/set_chat_description.py index 312b63eb..e2960408 100644 --- a/pyrogram/client/methods/chats/set_chat_description.py +++ b/pyrogram/client/methods/chats/set_chat_description.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class SetChatDescription(BaseClient): - def set_chat_description( + async def set_chat_description( self, chat_id: Union[int, str], description: str @@ -49,10 +49,10 @@ class SetChatDescription(BaseClient): app.set_chat_description(chat_id, "New Description") """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, (types.InputPeerChannel, types.InputPeerChat)): - self.send( + await self.send( functions.messages.EditChatAbout( peer=peer, about=description diff --git a/pyrogram/client/methods/chats/set_chat_permissions.py b/pyrogram/client/methods/chats/set_chat_permissions.py index afb8b6f1..3509baf4 100644 --- a/pyrogram/client/methods/chats/set_chat_permissions.py +++ b/pyrogram/client/methods/chats/set_chat_permissions.py @@ -24,7 +24,7 @@ from ...types.user_and_chats import Chat, ChatPermissions class SetChatPermissions(BaseClient): - def set_chat_permissions( + async def set_chat_permissions( self, chat_id: Union[int, str], permissions: ChatPermissions, @@ -63,9 +63,9 @@ class SetChatPermissions(BaseClient): ) ) """ - r = self.send( + r = await self.send( functions.messages.EditChatDefaultBannedRights( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), banned_rights=types.ChatBannedRights( until_date=0, send_messages=True if not permissions.can_send_messages else None, diff --git a/pyrogram/client/methods/chats/set_chat_photo.py b/pyrogram/client/methods/chats/set_chat_photo.py index f5ef954f..e6d35448 100644 --- a/pyrogram/client/methods/chats/set_chat_photo.py +++ b/pyrogram/client/methods/chats/set_chat_photo.py @@ -24,7 +24,7 @@ from ...ext import BaseClient, utils class SetChatPhoto(BaseClient): - def set_chat_photo( + async def set_chat_photo( self, chat_id: Union[int, str], *, @@ -79,32 +79,32 @@ class SetChatPhoto(BaseClient): # Set chat photo using an exiting Video file_id app.set_chat_photo(chat_id, video=video.file_id, file_ref=video.file_ref) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(photo, str): if os.path.isfile(photo): photo = types.InputChatUploadedPhoto( - file=self.save_file(photo), - video=self.save_file(video) + file=await self.save_file(photo), + video=await self.save_file(video) ) else: photo = utils.get_input_media_from_file_id(photo, file_ref, 2) photo = types.InputChatPhoto(id=photo.id) else: photo = types.InputChatUploadedPhoto( - file=self.save_file(photo), - video=self.save_file(video) + file=await self.save_file(photo), + video=await self.save_file(video) ) if isinstance(peer, types.InputPeerChat): - self.send( + await self.send( functions.messages.EditChatPhoto( chat_id=peer.chat_id, photo=photo ) ) elif isinstance(peer, types.InputPeerChannel): - self.send( + await self.send( functions.channels.EditPhoto( channel=peer, photo=photo diff --git a/pyrogram/client/methods/chats/set_chat_title.py b/pyrogram/client/methods/chats/set_chat_title.py index 9d6a2d24..a6b9bc71 100644 --- a/pyrogram/client/methods/chats/set_chat_title.py +++ b/pyrogram/client/methods/chats/set_chat_title.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class SetChatTitle(BaseClient): - def set_chat_title( + async def set_chat_title( self, chat_id: Union[int, str], title: str @@ -54,17 +54,17 @@ class SetChatTitle(BaseClient): app.set_chat_title(chat_id, "New Title") """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChat): - self.send( + await self.send( functions.messages.EditChatTitle( chat_id=peer.chat_id, title=title ) ) elif isinstance(peer, types.InputPeerChannel): - self.send( + await self.send( functions.channels.EditTitle( channel=peer, title=title diff --git a/pyrogram/client/methods/chats/set_slow_mode.py b/pyrogram/client/methods/chats/set_slow_mode.py index 8215c3b9..185a3824 100644 --- a/pyrogram/client/methods/chats/set_slow_mode.py +++ b/pyrogram/client/methods/chats/set_slow_mode.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class SetSlowMode(BaseClient): - def set_slow_mode( + async def set_slow_mode( self, chat_id: Union[int, str], seconds: Union[int, None] @@ -51,9 +51,9 @@ class SetSlowMode(BaseClient): app.set_slow_mode("pyrogramchat", None) """ - self.send( + await self.send( functions.channels.ToggleSlowMode( - channel=self.resolve_peer(chat_id), + channel=await self.resolve_peer(chat_id), seconds=0 if seconds is None else seconds ) ) diff --git a/pyrogram/client/methods/chats/unarchive_chats.py b/pyrogram/client/methods/chats/unarchive_chats.py index b004e4bb..dfba70a7 100644 --- a/pyrogram/client/methods/chats/unarchive_chats.py +++ b/pyrogram/client/methods/chats/unarchive_chats.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class UnarchiveChats(BaseClient): - def unarchive_chats( + async def unarchive_chats( self, chat_ids: Union[int, str, List[Union[int, str]]], ) -> bool: @@ -50,14 +50,19 @@ class UnarchiveChats(BaseClient): if not isinstance(chat_ids, list): chat_ids = [chat_ids] - self.send( + folder_peers = [] + + for chat in chat_ids: + folder_peers.append( + types.InputFolderPeer( + peer=await self.resolve_peer(chat), + folder_id=0 + ) + ) + + await self.send( functions.folders.EditPeerFolders( - folder_peers=[ - types.InputFolderPeer( - peer=self.resolve_peer(chat), - folder_id=0 - ) for chat in chat_ids - ] + folder_peers=folder_peers ) ) diff --git a/pyrogram/client/methods/chats/unban_chat_member.py b/pyrogram/client/methods/chats/unban_chat_member.py index fc0c9751..4a7b3940 100644 --- a/pyrogram/client/methods/chats/unban_chat_member.py +++ b/pyrogram/client/methods/chats/unban_chat_member.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class UnbanChatMember(BaseClient): - def unban_chat_member( + async def unban_chat_member( self, chat_id: Union[int, str], user_id: Union[int, str] @@ -49,10 +49,10 @@ class UnbanChatMember(BaseClient): # Unban chat member right now app.unban_chat_member(chat_id, user_id) """ - self.send( + await self.send( functions.channels.EditBanned( - channel=self.resolve_peer(chat_id), - user_id=self.resolve_peer(user_id), + channel=await self.resolve_peer(chat_id), + user_id=await self.resolve_peer(user_id), banned_rights=types.ChatBannedRights( until_date=0 ) diff --git a/pyrogram/client/methods/chats/unpin_chat_message.py b/pyrogram/client/methods/chats/unpin_chat_message.py index 6defd99f..0ca8254a 100644 --- a/pyrogram/client/methods/chats/unpin_chat_message.py +++ b/pyrogram/client/methods/chats/unpin_chat_message.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class UnpinChatMessage(BaseClient): - def unpin_chat_message( + async def unpin_chat_message( self, chat_id: Union[int, str] ) -> bool: @@ -43,9 +43,9 @@ class UnpinChatMessage(BaseClient): app.unpin_chat_message(chat_id) """ - self.send( + await self.send( functions.messages.UpdatePinnedMessage( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), id=0 ) ) diff --git a/pyrogram/client/methods/chats/update_chat_username.py b/pyrogram/client/methods/chats/update_chat_username.py index 251d6832..b1c57f1e 100644 --- a/pyrogram/client/methods/chats/update_chat_username.py +++ b/pyrogram/client/methods/chats/update_chat_username.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class UpdateChatUsername(BaseClient): - def update_chat_username( + async def update_chat_username( self, chat_id: Union[int, str], username: Union[str, None] @@ -50,11 +50,11 @@ class UpdateChatUsername(BaseClient): app.update_chat_username(chat_id, "new_username") """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChannel): return bool( - self.send( + await self.send( functions.channels.UpdateUsername( channel=peer, username=username or "" diff --git a/pyrogram/client/methods/contacts/add_contacts.py b/pyrogram/client/methods/contacts/add_contacts.py index a5bd4a93..7226d60b 100644 --- a/pyrogram/client/methods/contacts/add_contacts.py +++ b/pyrogram/client/methods/contacts/add_contacts.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class AddContacts(BaseClient): - def add_contacts( + async def add_contacts( self, contacts: List["pyrogram.InputPhoneContact"] ): @@ -47,7 +47,7 @@ class AddContacts(BaseClient): InputPhoneContact("38987654321", "Bar"), InputPhoneContact("01234567891", "Baz")]) """ - imported_contacts = self.send( + imported_contacts = await self.send( functions.contacts.ImportContacts( contacts=contacts ) diff --git a/pyrogram/client/methods/contacts/delete_contacts.py b/pyrogram/client/methods/contacts/delete_contacts.py index 27e6cfff..777d8b39 100644 --- a/pyrogram/client/methods/contacts/delete_contacts.py +++ b/pyrogram/client/methods/contacts/delete_contacts.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class DeleteContacts(BaseClient): - def delete_contacts( + async def delete_contacts( self, ids: List[int] ): @@ -47,14 +47,14 @@ class DeleteContacts(BaseClient): for i in ids: try: - input_user = self.resolve_peer(i) + input_user = await self.resolve_peer(i) except PeerIdInvalid: continue else: if isinstance(input_user, types.InputPeerUser): contacts.append(input_user) - return self.send( + return await self.send( functions.contacts.DeleteContacts( id=contacts ) diff --git a/pyrogram/client/methods/contacts/get_contacts.py b/pyrogram/client/methods/contacts/get_contacts.py index a0699e19..8f8392d6 100644 --- a/pyrogram/client/methods/contacts/get_contacts.py +++ b/pyrogram/client/methods/contacts/get_contacts.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging from typing import List @@ -27,7 +28,7 @@ log = logging.getLogger(__name__) class GetContacts(BaseClient): - def get_contacts(self) -> List["pyrogram.User"]: + async def get_contacts(self) -> List["pyrogram.User"]: """Get contacts from your Telegram address book. Returns: @@ -39,5 +40,5 @@ class GetContacts(BaseClient): contacts = app.get_contacts() print(contacts) """ - contacts = self.send(functions.contacts.GetContacts(hash=0)) + contacts = await self.send(functions.contacts.GetContacts(hash=0)) return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users) diff --git a/pyrogram/client/methods/contacts/get_contacts_count.py b/pyrogram/client/methods/contacts/get_contacts_count.py index b7871fde..8435557a 100644 --- a/pyrogram/client/methods/contacts/get_contacts_count.py +++ b/pyrogram/client/methods/contacts/get_contacts_count.py @@ -21,7 +21,7 @@ from ...ext import BaseClient class GetContactsCount(BaseClient): - def get_contacts_count(self) -> int: + async def get_contacts_count(self) -> int: """Get the total count of contacts from your Telegram address book. Returns: @@ -34,4 +34,4 @@ class GetContactsCount(BaseClient): print(count) """ - return len(self.send(functions.contacts.GetContacts(hash=0)).contacts) + return len((await self.send(functions.contacts.GetContacts(hash=0))).contacts) diff --git a/pyrogram/client/methods/messages/delete_messages.py b/pyrogram/client/methods/messages/delete_messages.py index a6af6a69..5deb6d5a 100644 --- a/pyrogram/client/methods/messages/delete_messages.py +++ b/pyrogram/client/methods/messages/delete_messages.py @@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient class DeleteMessages(BaseClient): - def delete_messages( + async def delete_messages( self, chat_id: Union[int, str], message_ids: Iterable[int], @@ -62,18 +62,18 @@ class DeleteMessages(BaseClient): # Delete messages only on your side (without revoking) app.delete_messages(chat_id, message_id, revoke=False) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) message_ids = list(message_ids) if not isinstance(message_ids, int) else [message_ids] if isinstance(peer, types.InputPeerChannel): - r = self.send( + r = await self.send( functions.channels.DeleteMessages( channel=peer, id=message_ids ) ) else: - r = self.send( + r = await self.send( functions.messages.DeleteMessages( id=message_ids, revoke=revoke or None diff --git a/pyrogram/client/methods/messages/download_media.py b/pyrogram/client/methods/messages/download_media.py index 01397135..3c1a8cbe 100644 --- a/pyrogram/client/methods/messages/download_media.py +++ b/pyrogram/client/methods/messages/download_media.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import binascii import os import struct @@ -32,7 +33,7 @@ DEFAULT_DOWNLOAD_DIR = "downloads/" class DownloadMedia(BaseClient): - def download_media( + async def download_media( self, message: Union["pyrogram.Message", str], file_ref: str = None, @@ -202,7 +203,7 @@ class DownloadMedia(BaseClient): except (AssertionError, binascii.Error, struct.error): raise FileIdInvalid from None - done = Event() + done = asyncio.Event() path = [None] directory, file_name = os.path.split(file_name) @@ -239,9 +240,9 @@ class DownloadMedia(BaseClient): ) # Cast to string because Path objects aren't supported by Python 3.5 - self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path)) + self.download_queue.put_nowait((data, str(directory), str(file_name), done, progress, progress_args, path)) if block: - done.wait() + await done.wait() return path[0] diff --git a/pyrogram/client/methods/messages/edit_inline_caption.py b/pyrogram/client/methods/messages/edit_inline_caption.py index 58cd05b2..335f878e 100644 --- a/pyrogram/client/methods/messages/edit_inline_caption.py +++ b/pyrogram/client/methods/messages/edit_inline_caption.py @@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient class EditInlineCaption(BaseClient): - def edit_inline_caption( + async def edit_inline_caption( self, inline_message_id: str, caption: str, @@ -58,7 +58,7 @@ class EditInlineCaption(BaseClient): # Bots only app.edit_inline_caption(inline_message_id, "new media caption") """ - return self.edit_inline_text( + return await self.edit_inline_text( inline_message_id=inline_message_id, text=caption, parse_mode=parse_mode, diff --git a/pyrogram/client/methods/messages/edit_inline_media.py b/pyrogram/client/methods/messages/edit_inline_media.py index f409ae06..74cb2910 100644 --- a/pyrogram/client/methods/messages/edit_inline_media.py +++ b/pyrogram/client/methods/messages/edit_inline_media.py @@ -29,7 +29,7 @@ from pyrogram.client.types.input_media import InputMedia class EditInlineMedia(BaseClient): - def edit_inline_media( + async def edit_inline_media( self, inline_message_id: str, media: InputMedia, @@ -109,11 +109,11 @@ class EditInlineMedia(BaseClient): else: media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5) - return self.send( + return await self.send( functions.messages.EditInlineBotMessage( id=utils.unpack_inline_message_id(inline_message_id), media=media, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) diff --git a/pyrogram/client/methods/messages/edit_inline_reply_markup.py b/pyrogram/client/methods/messages/edit_inline_reply_markup.py index 44e6197c..f5381806 100644 --- a/pyrogram/client/methods/messages/edit_inline_reply_markup.py +++ b/pyrogram/client/methods/messages/edit_inline_reply_markup.py @@ -22,7 +22,7 @@ from pyrogram.client.ext import BaseClient, utils class EditInlineReplyMarkup(BaseClient): - def edit_inline_reply_markup( + async def edit_inline_reply_markup( self, inline_message_id: str, reply_markup: "pyrogram.InlineKeyboardMarkup" = None @@ -50,7 +50,7 @@ class EditInlineReplyMarkup(BaseClient): InlineKeyboardMarkup([[ InlineKeyboardButton("New button", callback_data="new_data")]])) """ - return self.send( + return await self.send( functions.messages.EditInlineBotMessage( id=utils.unpack_inline_message_id(inline_message_id), reply_markup=reply_markup.write() if reply_markup else None, diff --git a/pyrogram/client/methods/messages/edit_inline_text.py b/pyrogram/client/methods/messages/edit_inline_text.py index 59b5ab73..cfdd232b 100644 --- a/pyrogram/client/methods/messages/edit_inline_text.py +++ b/pyrogram/client/methods/messages/edit_inline_text.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient, utils class EditInlineText(BaseClient): - def edit_inline_text( + async def edit_inline_text( self, inline_message_id: str, text: str, @@ -71,11 +71,11 @@ class EditInlineText(BaseClient): disable_web_page_preview=True) """ - return self.send( + return await self.send( functions.messages.EditInlineBotMessage( id=utils.unpack_inline_message_id(inline_message_id), no_webpage=disable_web_page_preview or None, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(text, parse_mode) + **await self.parser.parse(text, parse_mode) ) ) diff --git a/pyrogram/client/methods/messages/edit_message_caption.py b/pyrogram/client/methods/messages/edit_message_caption.py index a7c5b94c..01bd8147 100644 --- a/pyrogram/client/methods/messages/edit_message_caption.py +++ b/pyrogram/client/methods/messages/edit_message_caption.py @@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient class EditMessageCaption(BaseClient): - def edit_message_caption( + async def edit_message_caption( self, chat_id: Union[int, str], message_id: int, @@ -63,7 +63,7 @@ class EditMessageCaption(BaseClient): app.edit_message_caption(chat_id, message_id, "new media caption") """ - return self.edit_message_text( + return await self.edit_message_text( chat_id=chat_id, message_id=message_id, text=caption, diff --git a/pyrogram/client/methods/messages/edit_message_media.py b/pyrogram/client/methods/messages/edit_message_media.py index 9e91e945..765c1598 100644 --- a/pyrogram/client/methods/messages/edit_message_media.py +++ b/pyrogram/client/methods/messages/edit_message_media.py @@ -31,7 +31,7 @@ from pyrogram.client.types.input_media import InputMedia class EditMessageMedia(BaseClient): - def edit_message_media( + async def edit_message_media( self, chat_id: Union[int, str], message_id: int, @@ -85,11 +85,11 @@ class EditMessageMedia(BaseClient): if isinstance(media, InputMediaPhoto): if os.path.isfile(media.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaUploadedPhoto( - file=self.save_file(media.media) + file=await self.save_file(media.media) ) ) ) @@ -109,13 +109,13 @@ class EditMessageMedia(BaseClient): media = utils.get_input_media_from_file_id(media.media, media.file_ref, 2) elif isinstance(media, InputMediaVideo): if os.path.isfile(media.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(media.media) or "video/mp4", - thumb=self.save_file(media.thumb), - file=self.save_file(media.media), + thumb=await self.save_file(media.thumb), + file=await self.save_file(media.media), attributes=[ types.DocumentAttributeVideo( supports_streaming=media.supports_streaming or None, @@ -146,13 +146,13 @@ class EditMessageMedia(BaseClient): media = utils.get_input_media_from_file_id(media.media, media.file_ref, 4) elif isinstance(media, InputMediaAudio): if os.path.isfile(media.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(media.media) or "audio/mpeg", - thumb=self.save_file(media.thumb), - file=self.save_file(media.media), + thumb=await self.save_file(media.thumb), + file=await self.save_file(media.media), attributes=[ types.DocumentAttributeAudio( duration=media.duration, @@ -182,13 +182,13 @@ class EditMessageMedia(BaseClient): media = utils.get_input_media_from_file_id(media.media, media.file_ref, 9) elif isinstance(media, InputMediaAnimation): if os.path.isfile(media.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(media.media) or "video/mp4", thumb=self.save_file(media.thumb), - file=self.save_file(media.media), + file=await self.save_file(media.media), attributes=[ types.DocumentAttributeVideo( supports_streaming=True, @@ -220,13 +220,13 @@ class EditMessageMedia(BaseClient): media = utils.get_input_media_from_file_id(media.media, media.file_ref, 10) elif isinstance(media, InputMediaDocument): if os.path.isfile(media.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(media.media) or "application/zip", - thumb=self.save_file(media.thumb), - file=self.save_file(media.media), + thumb=await self.save_file(media.thumb), + file=await self.save_file(media.media), attributes=[ types.DocumentAttributeFilename( file_name=file_name or os.path.basename(media.media) @@ -250,19 +250,19 @@ class EditMessageMedia(BaseClient): else: media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5) - r = self.send( + r = await self.send( functions.messages.EditMessage( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), id=message_id, media=media, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) for i in r.updates: if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats} diff --git a/pyrogram/client/methods/messages/edit_message_reply_markup.py b/pyrogram/client/methods/messages/edit_message_reply_markup.py index 35c6cb3e..65fa26e2 100644 --- a/pyrogram/client/methods/messages/edit_message_reply_markup.py +++ b/pyrogram/client/methods/messages/edit_message_reply_markup.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class EditMessageReplyMarkup(BaseClient): - def edit_message_reply_markup( + async def edit_message_reply_markup( self, chat_id: Union[int, str], message_id: int, @@ -58,9 +58,9 @@ class EditMessageReplyMarkup(BaseClient): InlineKeyboardMarkup([[ InlineKeyboardButton("New button", callback_data="new_data")]])) """ - r = self.send( + r = await self.send( functions.messages.EditMessage( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), id=message_id, reply_markup=reply_markup.write() if reply_markup else None, ) @@ -68,7 +68,7 @@ class EditMessageReplyMarkup(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats} diff --git a/pyrogram/client/methods/messages/edit_message_text.py b/pyrogram/client/methods/messages/edit_message_text.py index df5ebace..a2b43f16 100644 --- a/pyrogram/client/methods/messages/edit_message_text.py +++ b/pyrogram/client/methods/messages/edit_message_text.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class EditMessageText(BaseClient): - def edit_message_text( + async def edit_message_text( self, chat_id: Union[int, str], message_id: int, @@ -75,19 +75,19 @@ class EditMessageText(BaseClient): disable_web_page_preview=True) """ - r = self.send( + r = await self.send( functions.messages.EditMessage( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), id=message_id, no_webpage=disable_web_page_preview or None, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(text, parse_mode) + **await self.parser.parse(text, parse_mode) ) ) for i in r.updates: if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats} diff --git a/pyrogram/client/methods/messages/forward_messages.py b/pyrogram/client/methods/messages/forward_messages.py index 1e2b1702..c7f47a52 100644 --- a/pyrogram/client/methods/messages/forward_messages.py +++ b/pyrogram/client/methods/messages/forward_messages.py @@ -20,11 +20,12 @@ from typing import Union, Iterable, List import pyrogram from pyrogram.api import functions, types + from ...ext import BaseClient class ForwardMessages(BaseClient): - def forward_messages( + async def forward_messages( self, chat_id: Union[int, str], from_chat_id: Union[int, str], @@ -94,11 +95,11 @@ class ForwardMessages(BaseClient): forwarded_messages = [] for chunk in [message_ids[i:i + 200] for i in range(0, len(message_ids), 200)]: - messages = self.get_messages(chat_id=from_chat_id, message_ids=chunk) + messages = await self.get_messages(chat_id=from_chat_id, message_ids=chunk) for message in messages: forwarded_messages.append( - message.forward( + await message.forward( chat_id, disable_notification=disable_notification, as_copy=True, @@ -109,10 +110,10 @@ class ForwardMessages(BaseClient): return pyrogram.List(forwarded_messages) if is_iterable else forwarded_messages[0] else: - r = self.send( + r = await self.send( functions.messages.ForwardMessages( - to_peer=self.resolve_peer(chat_id), - from_peer=self.resolve_peer(from_chat_id), + to_peer=await self.resolve_peer(chat_id), + from_peer=await self.resolve_peer(from_chat_id), id=message_ids, silent=disable_notification or None, random_id=[self.rnd_id() for _ in message_ids], @@ -128,7 +129,7 @@ class ForwardMessages(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): forwarded_messages.append( - pyrogram.Message._parse( + await pyrogram.Message._parse( self, i.message, users, chats ) diff --git a/pyrogram/client/methods/messages/get_history.py b/pyrogram/client/methods/messages/get_history.py index ea630aee..92a84d9b 100644 --- a/pyrogram/client/methods/messages/get_history.py +++ b/pyrogram/client/methods/messages/get_history.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging from typing import Union, List @@ -28,7 +29,7 @@ log = logging.getLogger(__name__) class GetHistory(BaseClient): - def get_history( + async def get_history( self, chat_id: Union[int, str], limit: int = 100, @@ -83,11 +84,11 @@ class GetHistory(BaseClient): offset_id = offset_id or (1 if reverse else 0) - messages = utils.parse_messages( + messages = await utils.parse_messages( self, - self.send( + await self.send( functions.messages.GetHistory( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), offset_id=offset_id, offset_date=offset_date, add_offset=offset * (-1 if reverse else 1) - (limit if reverse else 0), diff --git a/pyrogram/client/methods/messages/get_history_count.py b/pyrogram/client/methods/messages/get_history_count.py index cbdb1365..d7476f95 100644 --- a/pyrogram/client/methods/messages/get_history_count.py +++ b/pyrogram/client/methods/messages/get_history_count.py @@ -26,7 +26,7 @@ log = logging.getLogger(__name__) class GetHistoryCount(BaseClient): - def get_history_count( + async def get_history_count( self, chat_id: Union[int, str] ) -> int: @@ -51,9 +51,9 @@ class GetHistoryCount(BaseClient): app.get_history_count("pyrogramchat") """ - r = self.send( + r = await self.send( functions.messages.GetHistory( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), offset_id=0, offset_date=0, add_offset=0, diff --git a/pyrogram/client/methods/messages/get_messages.py b/pyrogram/client/methods/messages/get_messages.py index caf5bea0..b4199b30 100644 --- a/pyrogram/client/methods/messages/get_messages.py +++ b/pyrogram/client/methods/messages/get_messages.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging from typing import Union, Iterable, List @@ -30,7 +31,7 @@ log = logging.getLogger(__name__) class GetMessages(BaseClient): - def get_messages( + async def get_messages( self, chat_id: Union[int, str], message_ids: Union[int, Iterable[int]] = None, @@ -96,7 +97,7 @@ class GetMessages(BaseClient): if ids is None: raise ValueError("No argument supplied. Either pass message_ids or reply_to_message_ids") - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) is_iterable = not isinstance(ids, int) ids = list(ids) if is_iterable else [ids] @@ -110,8 +111,8 @@ class GetMessages(BaseClient): else: rpc = functions.messages.GetMessages(id=ids) - r = self.send(rpc) + r = await self.send(rpc) - messages = utils.parse_messages(self, r, replies=replies) + messages = await utils.parse_messages(self, r, replies=replies) return messages if is_iterable else messages[0] diff --git a/pyrogram/client/methods/messages/iter_history.py b/pyrogram/client/methods/messages/iter_history.py index 641f5000..04c7dc1a 100644 --- a/pyrogram/client/methods/messages/iter_history.py +++ b/pyrogram/client/methods/messages/iter_history.py @@ -16,14 +16,17 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Union, Generator +from typing import Union, Optional, Generator import pyrogram +from async_generator import async_generator, yield_ + from ...ext import BaseClient class IterHistory(BaseClient): - def iter_history( + @async_generator + async def iter_history( self, chat_id: Union[int, str], limit: int = 0, @@ -31,7 +34,7 @@ class IterHistory(BaseClient): offset_id: int = 0, offset_date: int = 0, reverse: bool = False - ) -> Generator["pyrogram.Message", None, None]: + ) -> Optional[Generator["pyrogram.Message", None, None]]: """Iterate through a chat history sequentially. This convenience method does the same as repeatedly calling :meth:`~Client.get_history` in a loop, thus saving @@ -76,7 +79,7 @@ class IterHistory(BaseClient): limit = min(100, total) while True: - messages = self.get_history( + messages = await self.get_history( chat_id=chat_id, limit=limit, offset=offset, @@ -91,7 +94,7 @@ class IterHistory(BaseClient): offset_id = messages[-1].message_id + (1 if reverse else 0) for message in messages: - yield message + await yield_(message) current += 1 diff --git a/pyrogram/client/methods/messages/read_history.py b/pyrogram/client/methods/messages/read_history.py index f23fa800..5e1e265b 100644 --- a/pyrogram/client/methods/messages/read_history.py +++ b/pyrogram/client/methods/messages/read_history.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class ReadHistory(BaseClient): - def read_history( + async def read_history( self, chat_id: Union[int, str], max_id: int = 0 @@ -53,7 +53,7 @@ class ReadHistory(BaseClient): app.read_history("pyrogramlounge", 123456) """ - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) if isinstance(peer, types.InputPeerChannel): q = functions.channels.ReadHistory( @@ -66,6 +66,6 @@ class ReadHistory(BaseClient): max_id=max_id ) - self.send(q) + await self.send(q) return True diff --git a/pyrogram/client/methods/messages/retract_vote.py b/pyrogram/client/methods/messages/retract_vote.py index fbb020b2..191c8c75 100644 --- a/pyrogram/client/methods/messages/retract_vote.py +++ b/pyrogram/client/methods/messages/retract_vote.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class RetractVote(BaseClient): - def retract_vote( + async def retract_vote( self, chat_id: Union[int, str], message_id: int @@ -48,9 +48,9 @@ class RetractVote(BaseClient): app.retract_vote(chat_id, message_id) """ - r = self.send( + r = await self.send( functions.messages.SendVote( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), msg_id=message_id, options=[] ) diff --git a/pyrogram/client/methods/messages/search_global.py b/pyrogram/client/methods/messages/search_global.py index 45262d78..2a889e31 100644 --- a/pyrogram/client/methods/messages/search_global.py +++ b/pyrogram/client/methods/messages/search_global.py @@ -16,7 +16,9 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Generator +from typing import Generator, Optional + +from async_generator import async_generator, yield_ import pyrogram from pyrogram.api import functions, types @@ -24,11 +26,12 @@ from pyrogram.client.ext import BaseClient, utils class SearchGlobal(BaseClient): - def search_global( + @async_generator + async def search_global( self, query: str, limit: int = 0, - ) -> Generator["pyrogram.Message", None, None]: + ) -> Optional[Generator["pyrogram.Message", None, None]]: """Search messages globally from all of your chats. .. note:: @@ -64,9 +67,9 @@ class SearchGlobal(BaseClient): offset_id = 0 while True: - messages = utils.parse_messages( + messages = await utils.parse_messages( self, - self.send( + await self.send( functions.messages.SearchGlobal( q=query, offset_rate=offset_date, @@ -84,11 +87,11 @@ class SearchGlobal(BaseClient): last = messages[-1] offset_date = last.date - offset_peer = self.resolve_peer(last.chat.id) + offset_peer = await self.resolve_peer(last.chat.id) offset_id = last.message_id for message in messages: - yield message + await yield_(message) current += 1 diff --git a/pyrogram/client/methods/messages/search_messages.py b/pyrogram/client/methods/messages/search_messages.py index 119c40e2..bfd56663 100644 --- a/pyrogram/client/methods/messages/search_messages.py +++ b/pyrogram/client/methods/messages/search_messages.py @@ -16,11 +16,12 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Union, List, Generator +from typing import Union, List, Generator, Optional import pyrogram from pyrogram.client.ext import BaseClient, utils from pyrogram.api import functions, types +from async_generator import async_generator, yield_ class Filters: @@ -46,7 +47,7 @@ POSSIBLE_VALUES = list(map(lambda x: x.lower(), filter(lambda x: not x.startswit # noinspection PyShadowingBuiltins -def get_chunk( +async def get_chunk( client: BaseClient, chat_id: Union[int, str], query: str = "", @@ -61,9 +62,9 @@ def get_chunk( raise ValueError('Invalid filter "{}". Possible values are: {}'.format( filter, ", ".join('"{}"'.format(v) for v in POSSIBLE_VALUES))) from None - r = client.send( + r = await client.send( functions.messages.Search( - peer=client.resolve_peer(chat_id), + peer=await client.resolve_peer(chat_id), q=query, filter=filter, min_date=0, @@ -74,7 +75,7 @@ def get_chunk( min_id=0, max_id=0, from_id=( - client.resolve_peer(from_user) + await client.resolve_peer(from_user) if from_user else None ), @@ -82,12 +83,13 @@ def get_chunk( ) ) - return utils.parse_messages(client, r) + return await utils.parse_messages(client, r) class SearchMessages(BaseClient): # noinspection PyShadowingBuiltins - def search_messages( + @async_generator + async def search_messages( self, chat_id: Union[int, str], query: str = "", @@ -95,7 +97,7 @@ class SearchMessages(BaseClient): filter: str = "empty", limit: int = 0, from_user: Union[int, str] = None - ) -> Generator["pyrogram.Message", None, None]: + ) -> Optional[Generator["pyrogram.Message", None, None]]: """Search for text and media messages inside a specific chat. Parameters: @@ -160,7 +162,7 @@ class SearchMessages(BaseClient): limit = min(100, total) while True: - messages = get_chunk( + messages = await get_chunk( client=self, chat_id=chat_id, query=query, @@ -176,7 +178,7 @@ class SearchMessages(BaseClient): offset += 100 for message in messages: - yield message + await yield_(message) current += 1 diff --git a/pyrogram/client/methods/messages/send_animation.py b/pyrogram/client/methods/messages/send_animation.py index e8d9285f..46562214 100644 --- a/pyrogram/client/methods/messages/send_animation.py +++ b/pyrogram/client/methods/messages/send_animation.py @@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing class SendAnimation(BaseClient): - def send_animation( + async def send_animation( self, chat_id: Union[int, str], animation: Union[str, BinaryIO], @@ -167,8 +167,8 @@ class SendAnimation(BaseClient): try: if isinstance(animation, str): if os.path.isfile(animation): - thumb = self.save_file(thumb) - file = self.save_file(animation, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(animation, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(animation) or "video/mp4", file=file, @@ -191,8 +191,8 @@ class SendAnimation(BaseClient): else: media = utils.get_input_media_from_file_id(animation, file_ref, 10) else: - thumb = self.save_file(thumb) - file = self.save_file(animation, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(animation, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(animation.name) or "video/mp4", file=file, @@ -211,27 +211,27 @@ class SendAnimation(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, random_id=self.rnd_id(), schedule_date=schedule_date, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) except FilePartMissing as e: - self.save_file(animation, file_id=file.id, file_part=e.x) + await self.save_file(animation, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - message = pyrogram.Message._parse( + message = await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, @@ -242,7 +242,7 @@ class SendAnimation(BaseClient): document = message.animation or message.document document_id = utils.get_input_media_from_file_id(document.file_id, document.file_ref).id - self.send( + await self.send( functions.messages.SaveGif( id=document_id, unsave=True diff --git a/pyrogram/client/methods/messages/send_audio.py b/pyrogram/client/methods/messages/send_audio.py index 8dfabe8c..dc460e97 100644 --- a/pyrogram/client/methods/messages/send_audio.py +++ b/pyrogram/client/methods/messages/send_audio.py @@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing class SendAudio(BaseClient): - def send_audio( + async def send_audio( self, chat_id: Union[int, str], audio: Union[str, BinaryIO], @@ -37,8 +37,7 @@ class SendAudio(BaseClient): duration: int = 0, performer: str = None, title: str = None, - thumb: Union[str, BinaryIO] = None, - file_name: str = None, + thumb: Union[str, BinaryIO] = None, file_name: str = None, disable_notification: bool = None, reply_to_message_id: int = None, schedule_date: int = None, @@ -167,8 +166,8 @@ class SendAudio(BaseClient): try: if isinstance(audio, str): if os.path.isfile(audio): - thumb = self.save_file(thumb) - file = self.save_file(audio, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(audio, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(audio) or "audio/mpeg", file=file, @@ -189,8 +188,8 @@ class SendAudio(BaseClient): else: media = utils.get_input_media_from_file_id(audio, file_ref, 9) else: - thumb = self.save_file(thumb) - file = self.save_file(audio, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(audio, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(audio.name) or "audio/mpeg", file=file, @@ -207,27 +206,27 @@ class SendAudio(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, random_id=self.rnd_id(), schedule_date=schedule_date, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) except FilePartMissing as e: - self.save_file(audio, file_id=file.id, file_part=e.x) + await self.save_file(audio, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_cached_media.py b/pyrogram/client/methods/messages/send_cached_media.py index 1ee139f5..d550cc27 100644 --- a/pyrogram/client/methods/messages/send_cached_media.py +++ b/pyrogram/client/methods/messages/send_cached_media.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient, utils class SendCachedMedia(BaseClient): - def send_cached_media( + async def send_cached_media( self, chat_id: Union[int, str], file_id: str, @@ -94,22 +94,22 @@ class SendCachedMedia(BaseClient): app.send_cached_media("me", "CAADBAADzg4AAvLQYAEz_x2EOgdRwBYE") """ - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=utils.get_input_media_from_file_id(file_id, file_ref), silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, random_id=self.rnd_id(), schedule_date=schedule_date, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_chat_action.py b/pyrogram/client/methods/messages/send_chat_action.py index 5d0dabb9..35a3d722 100644 --- a/pyrogram/client/methods/messages/send_chat_action.py +++ b/pyrogram/client/methods/messages/send_chat_action.py @@ -43,7 +43,7 @@ POSSIBLE_VALUES = list(map(lambda x: x.lower(), filter(lambda x: not x.startswit class SendChatAction(BaseClient): - def send_chat_action(self, chat_id: Union[int, str], action: str) -> bool: + async def send_chat_action(self, chat_id: Union[int, str], action: str) -> bool: """Tell the other party that something is happening on your side. Parameters: @@ -93,9 +93,9 @@ class SendChatAction(BaseClient): else: action = action() - return self.send( + return await self.send( functions.messages.SetTyping( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), action=action ) ) diff --git a/pyrogram/client/methods/messages/send_contact.py b/pyrogram/client/methods/messages/send_contact.py index 95ff6f38..0eacd8fd 100644 --- a/pyrogram/client/methods/messages/send_contact.py +++ b/pyrogram/client/methods/messages/send_contact.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class SendContact(BaseClient): - def send_contact( + async def send_contact( self, chat_id: Union[int, str], phone_number: str, @@ -83,9 +83,9 @@ class SendContact(BaseClient): app.send_contact("me", "+39 123 456 7890", "Dan") """ - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaContact( phone_number=phone_number, first_name=first_name, @@ -103,7 +103,7 @@ class SendContact(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_dice.py b/pyrogram/client/methods/messages/send_dice.py index b426f939..155185cd 100644 --- a/pyrogram/client/methods/messages/send_dice.py +++ b/pyrogram/client/methods/messages/send_dice.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class SendDice(BaseClient): - def send_dice( + async def send_dice( self, chat_id: Union[int, str], emoji: str = "🎲", @@ -80,9 +80,9 @@ class SendDice(BaseClient): app.send_dice("pyrogramlounge", "🏀") """ - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaDice(emoticon=emoji), silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, @@ -94,11 +94,8 @@ class SendDice(BaseClient): ) for i in r.updates: - if isinstance( - i, - (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) - ): - return pyrogram.Message._parse( + if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_document.py b/pyrogram/client/methods/messages/send_document.py index 8ca7fc4e..7f93d6cb 100644 --- a/pyrogram/client/methods/messages/send_document.py +++ b/pyrogram/client/methods/messages/send_document.py @@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing class SendDocument(BaseClient): - def send_document( + async def send_document( self, chat_id: Union[int, str], document: Union[str, BinaryIO], @@ -147,8 +147,8 @@ class SendDocument(BaseClient): try: if isinstance(document, str): if os.path.isfile(document): - thumb = self.save_file(thumb) - file = self.save_file(document, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(document, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(document) or "application/zip", file=file, @@ -165,8 +165,8 @@ class SendDocument(BaseClient): else: media = utils.get_input_media_from_file_id(document, file_ref, 5) else: - thumb = self.save_file(thumb) - file = self.save_file(document, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(document, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(document.name) or "application/zip", file=file, @@ -178,27 +178,27 @@ class SendDocument(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, random_id=self.rnd_id(), schedule_date=schedule_date, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) except FilePartMissing as e: - self.save_file(document, file_id=file.id, file_part=e.x) + await self.save_file(document, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_location.py b/pyrogram/client/methods/messages/send_location.py index 04b614ce..b23d9b10 100644 --- a/pyrogram/client/methods/messages/send_location.py +++ b/pyrogram/client/methods/messages/send_location.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class SendLocation(BaseClient): - def send_location( + async def send_location( self, chat_id: Union[int, str], latitude: float, @@ -75,9 +75,9 @@ class SendLocation(BaseClient): app.send_location("me", 51.500729, -0.124583) """ - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaGeoPoint( geo_point=types.InputGeoPoint( lat=latitude, @@ -95,7 +95,7 @@ class SendLocation(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_media_group.py b/pyrogram/client/methods/messages/send_media_group.py index 2a2cca74..0fad6e91 100644 --- a/pyrogram/client/methods/messages/send_media_group.py +++ b/pyrogram/client/methods/messages/send_media_group.py @@ -30,7 +30,7 @@ log = logging.getLogger(__name__) class SendMediaGroup(BaseClient): # TODO: Add progress parameter - def send_media_group( + async def send_media_group( self, chat_id: Union[int, str], media: List[Union["pyrogram.InputMediaPhoto", "pyrogram.InputMediaVideo"]], @@ -77,11 +77,11 @@ class SendMediaGroup(BaseClient): for i in media: if isinstance(i, pyrogram.InputMediaPhoto): if os.path.isfile(i.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaUploadedPhoto( - file=self.save_file(i.media) + file=await self.save_file(i.media) ) ) ) @@ -94,9 +94,9 @@ class SendMediaGroup(BaseClient): ) ) elif re.match("^https?://", i.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaPhotoExternal( url=i.media ) @@ -114,11 +114,11 @@ class SendMediaGroup(BaseClient): media = utils.get_input_media_from_file_id(i.media, i.file_ref, 2) elif isinstance(i, pyrogram.InputMediaVideo): if os.path.isfile(i.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaUploadedDocument( - file=self.save_file(i.media), + file=await self.save_file(i.media), thumb=self.save_file(i.thumb), mime_type=self.guess_mime_type(i.media) or "video/mp4", attributes=[ @@ -142,9 +142,9 @@ class SendMediaGroup(BaseClient): ) ) elif re.match("^https?://", i.media): - media = self.send( + media = await self.send( functions.messages.UploadMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaDocumentExternal( url=i.media ) @@ -165,20 +165,20 @@ class SendMediaGroup(BaseClient): types.InputSingleMedia( media=media, random_id=self.rnd_id(), - **self.parser.parse(i.caption, i.parse_mode) + **await self.parser.parse(i.caption, i.parse_mode) ) ) - r = self.send( + r = await self.send( functions.messages.SendMultiMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), multi_media=multi_media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id ) ) - return utils.parse_messages( + return await utils.parse_messages( self, types.messages.Messages( messages=[m.message for m in filter( diff --git a/pyrogram/client/methods/messages/send_message.py b/pyrogram/client/methods/messages/send_message.py index f719031c..58385bcf 100644 --- a/pyrogram/client/methods/messages/send_message.py +++ b/pyrogram/client/methods/messages/send_message.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class SendMessage(BaseClient): - def send_message( + async def send_message( self, chat_id: Union[int, str], text: str, @@ -116,11 +116,11 @@ class SendMessage(BaseClient): ])) """ - message, entities = self.parser.parse(text, parse_mode).values() + message, entities = (await self.parser.parse(text, parse_mode)).values() - r = self.send( + r = await self.send( functions.messages.SendMessage( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), no_webpage=disable_web_page_preview or None, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, @@ -133,7 +133,7 @@ class SendMessage(BaseClient): ) if isinstance(r, types.UpdateShortSentMessage): - peer = self.resolve_peer(chat_id) + peer = await self.resolve_peer(chat_id) peer_id = ( peer.user_id @@ -160,7 +160,7 @@ class SendMessage(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_photo.py b/pyrogram/client/methods/messages/send_photo.py index 63101685..7c2c194c 100644 --- a/pyrogram/client/methods/messages/send_photo.py +++ b/pyrogram/client/methods/messages/send_photo.py @@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing class SendPhoto(BaseClient): - def send_photo( + async def send_photo( self, chat_id: Union[int, str], photo: Union[str, BinaryIO], @@ -141,7 +141,7 @@ class SendPhoto(BaseClient): try: if isinstance(photo, str): if os.path.isfile(photo): - file = self.save_file(photo, progress=progress, progress_args=progress_args) + file = await self.save_file(photo, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedPhoto( file=file, ttl_seconds=ttl_seconds @@ -154,7 +154,7 @@ class SendPhoto(BaseClient): else: media = utils.get_input_media_from_file_id(photo, file_ref, 2) else: - file = self.save_file(photo, progress=progress, progress_args=progress_args) + file = await self.save_file(photo, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedPhoto( file=file, ttl_seconds=ttl_seconds @@ -162,27 +162,27 @@ class SendPhoto(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, random_id=self.rnd_id(), schedule_date=schedule_date, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) except FilePartMissing as e: - self.save_file(photo, file_id=file.id, file_part=e.x) + await self.save_file(photo, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_poll.py b/pyrogram/client/methods/messages/send_poll.py index e0baf5a8..7607a546 100644 --- a/pyrogram/client/methods/messages/send_poll.py +++ b/pyrogram/client/methods/messages/send_poll.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class SendPoll(BaseClient): - def send_poll( + async def send_poll( self, chat_id: Union[int, str], question: str, @@ -95,9 +95,9 @@ class SendPoll(BaseClient): app.send_poll(chat_id, "Is this a poll question?", ["Yes", "No", "Maybe"]) """ - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaPoll( poll=types.Poll( id=0, @@ -123,7 +123,7 @@ class SendPoll(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_sticker.py b/pyrogram/client/methods/messages/send_sticker.py index 529a4f8d..8025a0dd 100644 --- a/pyrogram/client/methods/messages/send_sticker.py +++ b/pyrogram/client/methods/messages/send_sticker.py @@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing class SendSticker(BaseClient): - def send_sticker( + async def send_sticker( self, chat_id: Union[int, str], sticker: Union[str, BinaryIO], @@ -117,7 +117,7 @@ class SendSticker(BaseClient): try: if isinstance(sticker, str): if os.path.isfile(sticker): - file = self.save_file(sticker, progress=progress, progress_args=progress_args) + file = await self.save_file(sticker, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(sticker) or "image/webp", file=file, @@ -132,7 +132,7 @@ class SendSticker(BaseClient): else: media = utils.get_input_media_from_file_id(sticker, file_ref, 8) else: - file = self.save_file(sticker, progress=progress, progress_args=progress_args) + file = await self.save_file(sticker, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(sticker.name) or "image/webp", file=file, @@ -143,9 +143,9 @@ class SendSticker(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, @@ -156,14 +156,14 @@ class SendSticker(BaseClient): ) ) except FilePartMissing as e: - self.save_file(sticker, file_id=file.id, file_part=e.x) + await self.save_file(sticker, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_venue.py b/pyrogram/client/methods/messages/send_venue.py index 98ff4103..f6f09d5e 100644 --- a/pyrogram/client/methods/messages/send_venue.py +++ b/pyrogram/client/methods/messages/send_venue.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class SendVenue(BaseClient): - def send_venue( + async def send_venue( self, chat_id: Union[int, str], latitude: float, @@ -94,9 +94,9 @@ class SendVenue(BaseClient): "me", 51.500729, -0.124583, "Elizabeth Tower", "Westminster, London SW1A 0AA, UK") """ - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=types.InputMediaVenue( geo_point=types.InputGeoPoint( lat=latitude, @@ -119,7 +119,7 @@ class SendVenue(BaseClient): for i in r.updates: if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_video.py b/pyrogram/client/methods/messages/send_video.py index 40691771..87cecf5a 100644 --- a/pyrogram/client/methods/messages/send_video.py +++ b/pyrogram/client/methods/messages/send_video.py @@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing class SendVideo(BaseClient): - def send_video( + async def send_video( self, chat_id: Union[int, str], video: Union[str, BinaryIO], @@ -164,8 +164,8 @@ class SendVideo(BaseClient): try: if isinstance(video, str): if os.path.isfile(video): - thumb = self.save_file(thumb) - file = self.save_file(video, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(video, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(video) or "video/mp4", file=file, @@ -187,8 +187,8 @@ class SendVideo(BaseClient): else: media = utils.get_input_media_from_file_id(video, file_ref, 4) else: - thumb = self.save_file(thumb) - file = self.save_file(video, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(video, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(video.name) or "video/mp4", file=file, @@ -206,27 +206,27 @@ class SendVideo(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, random_id=self.rnd_id(), schedule_date=schedule_date, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) except FilePartMissing as e: - self.save_file(video, file_id=file.id, file_part=e.x) + await self.save_file(video, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_video_note.py b/pyrogram/client/methods/messages/send_video_note.py index d3b32834..c204f6ca 100644 --- a/pyrogram/client/methods/messages/send_video_note.py +++ b/pyrogram/client/methods/messages/send_video_note.py @@ -26,7 +26,7 @@ from pyrogram.errors import FilePartMissing class SendVideoNote(BaseClient): - def send_video_note( + async def send_video_note( self, chat_id: Union[int, str], video_note: Union[str, BinaryIO], @@ -131,8 +131,8 @@ class SendVideoNote(BaseClient): try: if isinstance(video_note, str): if os.path.isfile(video_note): - thumb = self.save_file(thumb) - file = self.save_file(video_note, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(video_note, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(video_note) or "video/mp4", file=file, @@ -149,8 +149,8 @@ class SendVideoNote(BaseClient): else: media = utils.get_input_media_from_file_id(video_note, file_ref, 13) else: - thumb = self.save_file(thumb) - file = self.save_file(video_note, progress=progress, progress_args=progress_args) + thumb = await self.save_file(thumb) + file = await self.save_file(video_note, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(video_note.name) or "video/mp4", file=file, @@ -167,9 +167,9 @@ class SendVideoNote(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, @@ -180,14 +180,14 @@ class SendVideoNote(BaseClient): ) ) except FilePartMissing as e: - self.save_file(video_note, file_id=file.id, file_part=e.x) + await self.save_file(video_note, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/send_voice.py b/pyrogram/client/methods/messages/send_voice.py index f99b4236..98221e8d 100644 --- a/pyrogram/client/methods/messages/send_voice.py +++ b/pyrogram/client/methods/messages/send_voice.py @@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing class SendVoice(BaseClient): - def send_voice( + async def send_voice( self, chat_id: Union[int, str], voice: Union[str, BinaryIO], @@ -136,7 +136,7 @@ class SendVoice(BaseClient): try: if isinstance(voice, str): if os.path.isfile(voice): - file = self.save_file(voice, progress=progress, progress_args=progress_args) + file = await self.save_file(voice, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(voice) or "audio/mpeg", file=file, @@ -154,7 +154,7 @@ class SendVoice(BaseClient): else: media = utils.get_input_media_from_file_id(voice, file_ref, 3) else: - file = self.save_file(voice, progress=progress, progress_args=progress_args) + file = await self.save_file(voice, progress=progress, progress_args=progress_args) media = types.InputMediaUploadedDocument( mime_type=self.guess_mime_type(voice.name) or "audio/mpeg", file=file, @@ -168,27 +168,27 @@ class SendVoice(BaseClient): while True: try: - r = self.send( + r = await self.send( functions.messages.SendMedia( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), media=media, silent=disable_notification or None, reply_to_msg_id=reply_to_message_id, random_id=self.rnd_id(), schedule_date=schedule_date, reply_markup=reply_markup.write() if reply_markup else None, - **self.parser.parse(caption, parse_mode) + **await self.parser.parse(caption, parse_mode) ) ) except FilePartMissing as e: - self.save_file(voice, file_id=file.id, file_part=e.x) + await self.save_file(voice, file_id=file.id, file_part=e.x) else: for i in r.updates: if isinstance( i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) ): - return pyrogram.Message._parse( + return await pyrogram.Message._parse( self, i.message, {i.id: i for i in r.users}, {i.id: i for i in r.chats}, diff --git a/pyrogram/client/methods/messages/stop_poll.py b/pyrogram/client/methods/messages/stop_poll.py index 4d133f7f..79498e45 100644 --- a/pyrogram/client/methods/messages/stop_poll.py +++ b/pyrogram/client/methods/messages/stop_poll.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class StopPoll(BaseClient): - def stop_poll( + async def stop_poll( self, chat_id: Union[int, str], message_id: int, @@ -54,11 +54,11 @@ class StopPoll(BaseClient): app.stop_poll(chat_id, message_id) """ - poll = self.get_messages(chat_id, message_id).poll + poll = (await self.get_messages(chat_id, message_id)).poll - r = self.send( + r = await self.send( functions.messages.EditMessage( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), id=message_id, media=types.InputMediaPoll( poll=types.Poll( diff --git a/pyrogram/client/methods/messages/vote_poll.py b/pyrogram/client/methods/messages/vote_poll.py index 335ca7cb..667df0fb 100644 --- a/pyrogram/client/methods/messages/vote_poll.py +++ b/pyrogram/client/methods/messages/vote_poll.py @@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient class VotePoll(BaseClient): - def vote_poll( + async def vote_poll( self, chat_id: Union[int, str], message_id: id, @@ -53,12 +53,12 @@ class VotePoll(BaseClient): app.vote_poll(chat_id, message_id, 6) """ - poll = self.get_messages(chat_id, message_id).poll + poll = (await self.get_messages(chat_id, message_id)).poll options = [options] if not isinstance(options, list) else options - r = self.send( + r = await self.send( functions.messages.SendVote( - peer=self.resolve_peer(chat_id), + peer=await self.resolve_peer(chat_id), msg_id=message_id, options=[poll.options[option].data for option in options] ) diff --git a/pyrogram/client/methods/password/change_cloud_password.py b/pyrogram/client/methods/password/change_cloud_password.py index 20625657..54ec2891 100644 --- a/pyrogram/client/methods/password/change_cloud_password.py +++ b/pyrogram/client/methods/password/change_cloud_password.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class ChangeCloudPassword(BaseClient): - def change_cloud_password( + async def change_cloud_password( self, current_password: str, new_password: str, @@ -57,7 +57,7 @@ class ChangeCloudPassword(BaseClient): # Change password and hint app.change_cloud_password("current_password", "new_password", new_hint="hint") """ - r = self.send(functions.account.GetPassword()) + r = await self.send(functions.account.GetPassword()) if not r.has_password: raise ValueError("There is no cloud password to change") @@ -66,7 +66,7 @@ class ChangeCloudPassword(BaseClient): new_hash = btoi(compute_hash(r.new_algo, new_password)) new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p))) - self.send( + await self.send( functions.account.UpdatePasswordSettings( password=compute_check(r, current_password), new_settings=types.account.PasswordInputSettings( diff --git a/pyrogram/client/methods/password/enable_cloud_password.py b/pyrogram/client/methods/password/enable_cloud_password.py index c8052aa8..b6291943 100644 --- a/pyrogram/client/methods/password/enable_cloud_password.py +++ b/pyrogram/client/methods/password/enable_cloud_password.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class EnableCloudPassword(BaseClient): - def enable_cloud_password( + async def enable_cloud_password( self, password: str, hint: str = "", @@ -62,7 +62,7 @@ class EnableCloudPassword(BaseClient): # Enable password with hint and email app.enable_cloud_password("password", hint="hint", email="user@email.com") """ - r = self.send(functions.account.GetPassword()) + r = await self.send(functions.account.GetPassword()) if r.has_password: raise ValueError("There is already a cloud password enabled") @@ -71,7 +71,7 @@ class EnableCloudPassword(BaseClient): new_hash = btoi(compute_hash(r.new_algo, password)) new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p))) - self.send( + await self.send( functions.account.UpdatePasswordSettings( password=types.InputCheckPasswordEmpty(), new_settings=types.account.PasswordInputSettings( diff --git a/pyrogram/client/methods/password/remove_cloud_password.py b/pyrogram/client/methods/password/remove_cloud_password.py index 21ebb8b3..9d41a9db 100644 --- a/pyrogram/client/methods/password/remove_cloud_password.py +++ b/pyrogram/client/methods/password/remove_cloud_password.py @@ -22,7 +22,7 @@ from ...ext import BaseClient class RemoveCloudPassword(BaseClient): - def remove_cloud_password( + async def remove_cloud_password( self, password: str ) -> bool: @@ -43,12 +43,12 @@ class RemoveCloudPassword(BaseClient): app.remove_cloud_password("password") """ - r = self.send(functions.account.GetPassword()) + r = await self.send(functions.account.GetPassword()) if not r.has_password: raise ValueError("There is no cloud password to remove") - self.send( + await self.send( functions.account.UpdatePasswordSettings( password=compute_check(r, password), new_settings=types.account.PasswordInputSettings( diff --git a/pyrogram/client/methods/users/block_user.py b/pyrogram/client/methods/users/block_user.py index b61507f9..8dd6e09d 100644 --- a/pyrogram/client/methods/users/block_user.py +++ b/pyrogram/client/methods/users/block_user.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class BlockUser(BaseClient): - def block_user( + async def block_user( self, user_id: Union[int, str] ) -> bool: @@ -44,9 +44,9 @@ class BlockUser(BaseClient): app.block_user(user_id) """ return bool( - self.send( + await self.send( functions.contacts.Block( - id=self.resolve_peer(user_id) + id=await self.resolve_peer(user_id) ) ) ) diff --git a/pyrogram/client/methods/users/delete_profile_photos.py b/pyrogram/client/methods/users/delete_profile_photos.py index 66ad219f..ac184da5 100644 --- a/pyrogram/client/methods/users/delete_profile_photos.py +++ b/pyrogram/client/methods/users/delete_profile_photos.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class DeleteProfilePhotos(BaseClient): - def delete_profile_photos( + async def delete_profile_photos( self, photo_ids: Union[str, List[str]] ) -> bool: @@ -53,7 +53,7 @@ class DeleteProfilePhotos(BaseClient): photo_ids = photo_ids if isinstance(photo_ids, list) else [photo_ids] input_photos = [utils.get_input_media_from_file_id(i).id for i in photo_ids] - return bool(self.send( + return bool(await self.send( functions.photos.DeletePhotos( id=input_photos ) diff --git a/pyrogram/client/methods/users/get_common_chats.py b/pyrogram/client/methods/users/get_common_chats.py index 35d037fb..fab202fd 100644 --- a/pyrogram/client/methods/users/get_common_chats.py +++ b/pyrogram/client/methods/users/get_common_chats.py @@ -24,7 +24,7 @@ from ...ext import BaseClient class GetCommonChats(BaseClient): - def get_common_chats(self, user_id: Union[int, str]) -> list: + async def get_common_chats(self, user_id: Union[int, str]) -> list: """Get the common chats you have with a user. Parameters: @@ -46,10 +46,10 @@ class GetCommonChats(BaseClient): print(common) """ - peer = self.resolve_peer(user_id) + peer = await self.resolve_peer(user_id) if isinstance(peer, types.InputPeerUser): - r = self.send( + r = await self.send( functions.messages.GetCommonChats( user_id=peer, max_id=0, diff --git a/pyrogram/client/methods/users/get_me.py b/pyrogram/client/methods/users/get_me.py index 1814fa6d..0efbddb2 100644 --- a/pyrogram/client/methods/users/get_me.py +++ b/pyrogram/client/methods/users/get_me.py @@ -22,7 +22,7 @@ from ...ext import BaseClient class GetMe(BaseClient): - def get_me(self) -> "pyrogram.User": + async def get_me(self) -> "pyrogram.User": """Get your own user identity. Returns: @@ -36,9 +36,9 @@ class GetMe(BaseClient): """ return pyrogram.User._parse( self, - self.send( + (await self.send( functions.users.GetFullUser( id=types.InputPeerSelf() ) - ).user + )).user ) diff --git a/pyrogram/client/methods/users/get_profile_photos.py b/pyrogram/client/methods/users/get_profile_photos.py index ec23e651..fded8dcb 100644 --- a/pyrogram/client/methods/users/get_profile_photos.py +++ b/pyrogram/client/methods/users/get_profile_photos.py @@ -25,7 +25,7 @@ from ...ext import BaseClient class GetProfilePhotos(BaseClient): - def get_profile_photos( + async def get_profile_photos( self, chat_id: Union[int, str], offset: int = 0, @@ -62,12 +62,12 @@ class GetProfilePhotos(BaseClient): # Get 3 profile photos of a user, skip the first 5 app.get_profile_photos("haskell", limit=3, offset=5) """ - peer_id = self.resolve_peer(chat_id) + peer_id = await self.resolve_peer(chat_id) if isinstance(peer_id, types.InputPeerChannel): - r = utils.parse_messages( + r = await utils.parse_messages( self, - self.send( + await self.send( functions.messages.Search( peer=peer_id, q="", @@ -86,7 +86,7 @@ class GetProfilePhotos(BaseClient): return pyrogram.List([message.new_chat_photo for message in r][:limit]) else: - r = self.send( + r = await self.send( functions.photos.GetUserPhotos( user_id=peer_id, offset=offset, diff --git a/pyrogram/client/methods/users/get_profile_photos_count.py b/pyrogram/client/methods/users/get_profile_photos_count.py index a927d8bd..affc00e1 100644 --- a/pyrogram/client/methods/users/get_profile_photos_count.py +++ b/pyrogram/client/methods/users/get_profile_photos_count.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class GetProfilePhotosCount(BaseClient): - def get_profile_photos_count(self, chat_id: Union[int, str]) -> int: + async def get_profile_photos_count(self, chat_id: Union[int, str]) -> int: """Get the total count of profile pictures for a user. Parameters: @@ -42,10 +42,10 @@ class GetProfilePhotosCount(BaseClient): print(count) """ - peer_id = self.resolve_peer(chat_id) + peer_id = await self.resolve_peer(chat_id) if isinstance(peer_id, types.InputPeerChannel): - r = self.send( + r = await self.send( functions.messages.GetSearchCounters( peer=peer_id, filters=[types.InputMessagesFilterChatPhotos()], @@ -54,7 +54,7 @@ class GetProfilePhotosCount(BaseClient): return r[0].count else: - r = self.send( + r = await self.send( functions.photos.GetUserPhotos( user_id=peer_id, offset=0, diff --git a/pyrogram/client/methods/users/get_users.py b/pyrogram/client/methods/users/get_users.py index b115cb03..05476bc4 100644 --- a/pyrogram/client/methods/users/get_users.py +++ b/pyrogram/client/methods/users/get_users.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio from typing import Iterable, Union, List import pyrogram @@ -24,7 +25,7 @@ from ...ext import BaseClient class GetUsers(BaseClient): - def get_users( + async def get_users( self, user_ids: Union[Iterable[Union[int, str]], int, str] ) -> Union["pyrogram.User", List["pyrogram.User"]]: @@ -53,9 +54,9 @@ class GetUsers(BaseClient): """ is_iterable = not isinstance(user_ids, (int, str)) user_ids = list(user_ids) if is_iterable else [user_ids] - user_ids = [self.resolve_peer(i) for i in user_ids] + user_ids = await asyncio.gather(*[self.resolve_peer(i) for i in user_ids]) - r = self.send( + r = await self.send( functions.users.GetUsers( id=user_ids ) diff --git a/pyrogram/client/methods/users/iter_profile_photos.py b/pyrogram/client/methods/users/iter_profile_photos.py index fb09cff7..bdd3d8b7 100644 --- a/pyrogram/client/methods/users/iter_profile_photos.py +++ b/pyrogram/client/methods/users/iter_profile_photos.py @@ -16,19 +16,22 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from typing import Union, Generator +from typing import Union, Generator, Optional import pyrogram +from async_generator import async_generator, yield_ + from ...ext import BaseClient class IterProfilePhotos(BaseClient): - def iter_profile_photos( + @async_generator + async def iter_profile_photos( self, chat_id: Union[int, str], offset: int = 0, limit: int = 0, - ) -> Generator["pyrogram.Photo", None, None]: + ) -> Optional[Generator["pyrogram.Message", None, None]]: """Iterate through a chat or a user profile photos sequentially. This convenience method does the same as repeatedly calling :meth:`~Client.get_profile_photos` in a loop, thus @@ -62,7 +65,7 @@ class IterProfilePhotos(BaseClient): limit = min(100, total) while True: - photos = self.get_profile_photos( + photos = await self.get_profile_photos( chat_id=chat_id, offset=offset, limit=limit @@ -74,7 +77,7 @@ class IterProfilePhotos(BaseClient): offset += len(photos) for photo in photos: - yield photo + await yield_(photo) current += 1 diff --git a/pyrogram/client/methods/users/set_profile_photo.py b/pyrogram/client/methods/users/set_profile_photo.py index 01741df9..b9dbbf10 100644 --- a/pyrogram/client/methods/users/set_profile_photo.py +++ b/pyrogram/client/methods/users/set_profile_photo.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class SetProfilePhoto(BaseClient): - def set_profile_photo( + async def set_profile_photo( self, *, photo: Union[str, BinaryIO] = None, @@ -64,10 +64,10 @@ class SetProfilePhoto(BaseClient): """ return bool( - self.send( + await self.send( functions.photos.UploadProfilePhoto( - file=self.save_file(photo), - video=self.save_file(video) + file=await self.save_file(photo), + video=await self.save_file(video) ) ) ) diff --git a/pyrogram/client/methods/users/unblock_user.py b/pyrogram/client/methods/users/unblock_user.py index 8459cfd6..fddf9ff6 100644 --- a/pyrogram/client/methods/users/unblock_user.py +++ b/pyrogram/client/methods/users/unblock_user.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class UnblockUser(BaseClient): - def unblock_user( + async def unblock_user( self, user_id: Union[int, str] ) -> bool: @@ -44,9 +44,9 @@ class UnblockUser(BaseClient): app.unblock_user(user_id) """ return bool( - self.send( + await self.send( functions.contacts.Unblock( - id=self.resolve_peer(user_id) + id=await self.resolve_peer(user_id) ) ) ) diff --git a/pyrogram/client/methods/users/update_profile.py b/pyrogram/client/methods/users/update_profile.py index da7d62f4..145d5035 100644 --- a/pyrogram/client/methods/users/update_profile.py +++ b/pyrogram/client/methods/users/update_profile.py @@ -21,7 +21,7 @@ from ...ext import BaseClient class UpdateProfile(BaseClient): - def update_profile( + async def update_profile( self, first_name: str = None, last_name: str = None, @@ -60,7 +60,7 @@ class UpdateProfile(BaseClient): """ return bool( - self.send( + await self.send( functions.account.UpdateProfile( first_name=first_name, last_name=last_name, diff --git a/pyrogram/client/methods/users/update_username.py b/pyrogram/client/methods/users/update_username.py index 24a12a8e..50e2388c 100644 --- a/pyrogram/client/methods/users/update_username.py +++ b/pyrogram/client/methods/users/update_username.py @@ -23,7 +23,7 @@ from ...ext import BaseClient class UpdateUsername(BaseClient): - def update_username( + async def update_username( self, username: Union[str, None] ) -> bool: @@ -47,7 +47,7 @@ class UpdateUsername(BaseClient): """ return bool( - self.send( + await self.send( functions.account.UpdateUsername( username=username or "" ) diff --git a/pyrogram/client/parser/html.py b/pyrogram/client/parser/html.py index 35dd770c..3ce70d51 100644 --- a/pyrogram/client/parser/html.py +++ b/pyrogram/client/parser/html.py @@ -110,7 +110,7 @@ class HTML: def __init__(self, client: Union["pyrogram.BaseClient", None]): self.client = client - def parse(self, text: str): + async def parse(self, text: str): # Strip whitespace characters from the end of the message, but preserve closing tags text = re.sub(r"\s*()\s*$", r"\1", text) @@ -132,7 +132,7 @@ class HTML: if isinstance(entity, types.InputMessageEntityMentionName): try: if self.client is not None: - entity.user_id = self.client.resolve_peer(entity.user_id) + entity.user_id = await self.client.resolve_peer(entity.user_id) except PeerIdInvalid: continue diff --git a/pyrogram/client/parser/markdown.py b/pyrogram/client/parser/markdown.py index 5f4ab258..4c954efd 100644 --- a/pyrogram/client/parser/markdown.py +++ b/pyrogram/client/parser/markdown.py @@ -56,7 +56,7 @@ class Markdown: def __init__(self, client: Union["pyrogram.BaseClient", None]): self.html = HTML(client) - def parse(self, text: str, strict: bool = False): + async def parse(self, text: str, strict: bool = False): if strict: text = html.escape(text) @@ -102,7 +102,7 @@ class Markdown: text = utils.replace_once(text, delim, tag, start) - return self.html.parse(text) + return await self.html.parse(text) @staticmethod def unparse(text: str, entities: list): diff --git a/pyrogram/client/parser/parser.py b/pyrogram/client/parser/parser.py index 968b95f1..eb4f2e19 100644 --- a/pyrogram/client/parser/parser.py +++ b/pyrogram/client/parser/parser.py @@ -30,7 +30,7 @@ class Parser: self.html = HTML(client) self.markdown = Markdown(client) - def parse(self, text: str, mode: Union[str, None] = object): + async def parse(self, text: str, mode: Union[str, None] = object): text = str(text).strip() if mode == object: @@ -48,13 +48,13 @@ class Parser: mode = mode.lower() if mode == "combined": - return self.markdown.parse(text) + return await self.markdown.parse(text) if mode in ["markdown", "md"]: - return self.markdown.parse(text, True) + return await self.markdown.parse(text, True) if mode == "html": - return self.html.parse(text) + return await self.html.parse(text) raise ValueError('parse_mode must be one of {} or None. Not "{}"'.format( ", ".join('"{}"'.format(m) for m in pyrogram.Client.PARSE_MODES[:-1]), diff --git a/pyrogram/client/types/bots_and_keyboards/callback_query.py b/pyrogram/client/types/bots_and_keyboards/callback_query.py index ec4048fc..a15a7c35 100644 --- a/pyrogram/client/types/bots_and_keyboards/callback_query.py +++ b/pyrogram/client/types/bots_and_keyboards/callback_query.py @@ -89,12 +89,12 @@ class CallbackQuery(Object, Update): self.matches = matches @staticmethod - def _parse(client, callback_query, users) -> "CallbackQuery": + async def _parse(client, callback_query, users) -> "CallbackQuery": message = None inline_message_id = None if isinstance(callback_query, types.UpdateBotCallbackQuery): - message = client.get_messages(utils.get_peer_id(callback_query.peer), callback_query.msg_id) + message = await client.get_messages(utils.get_peer_id(callback_query.peer), callback_query.msg_id) elif isinstance(callback_query, types.UpdateInlineBotCallbackQuery): inline_message_id = b64encode( pack( @@ -124,7 +124,7 @@ class CallbackQuery(Object, Update): client=client ) - def answer(self, text: str = None, show_alert: bool = None, url: str = None, cache_time: int = 0): + async def answer(self, text: str = None, show_alert: bool = None, url: str = None, cache_time: int = 0): """Bound method *answer* of :obj:`CallbackQuery`. Use this method as a shortcut for: @@ -160,7 +160,7 @@ class CallbackQuery(Object, Update): The maximum amount of time in seconds that the result of the callback query may be cached client-side. Telegram apps will support caching starting in version 3.14. Defaults to 0. """ - return self._client.answer_callback_query( + return await self._client.answer_callback_query( callback_query_id=self.id, text=text, show_alert=show_alert, @@ -168,7 +168,7 @@ class CallbackQuery(Object, Update): cache_time=cache_time ) - def edit_message_text( + async def edit_message_text( self, text: str, parse_mode: Union[str, None] = object, @@ -204,7 +204,7 @@ class CallbackQuery(Object, Update): RPCError: In case of a Telegram RPC error. """ if self.inline_message_id is None: - return self._client.edit_message_text( + return await self._client.edit_message_text( chat_id=self.message.chat.id, message_id=self.message.message_id, text=text, @@ -213,7 +213,7 @@ class CallbackQuery(Object, Update): reply_markup=reply_markup ) else: - return self._client.edit_inline_text( + return await self._client.edit_inline_text( inline_message_id=self.inline_message_id, text=text, parse_mode=parse_mode, @@ -221,7 +221,7 @@ class CallbackQuery(Object, Update): reply_markup=reply_markup ) - def edit_message_caption( + async def edit_message_caption( self, caption: str, parse_mode: Union[str, None] = object, @@ -252,9 +252,9 @@ class CallbackQuery(Object, Update): Raises: RPCError: In case of a Telegram RPC error. """ - return self.edit_message_text(caption, parse_mode, reply_markup) + return await self.edit_message_text(caption, parse_mode, reply_markup) - def edit_message_media( + async def edit_message_media( self, media: "pyrogram.InputMedia", reply_markup: "pyrogram.InlineKeyboardMarkup" = None @@ -278,20 +278,20 @@ class CallbackQuery(Object, Update): RPCError: In case of a Telegram RPC error. """ if self.inline_message_id is None: - return self._client.edit_message_media( + return await self._client.edit_message_media( chat_id=self.message.chat.id, message_id=self.message.message_id, media=media, reply_markup=reply_markup ) else: - return self._client.edit_inline_media( + return await self._client.edit_inline_media( inline_message_id=self.inline_message_id, media=media, reply_markup=reply_markup ) - def edit_message_reply_markup( + async def edit_message_reply_markup( self, reply_markup: "pyrogram.InlineKeyboardMarkup" = None ) -> Union["pyrogram.Message", bool]: @@ -311,13 +311,13 @@ class CallbackQuery(Object, Update): RPCError: In case of a Telegram RPC error. """ if self.inline_message_id is None: - return self._client.edit_message_reply_markup( + return await self._client.edit_message_reply_markup( chat_id=self.message.chat.id, message_id=self.message.message_id, reply_markup=reply_markup ) else: - return self._client.edit_inline_reply_markup( + return await self._client.edit_inline_reply_markup( inline_message_id=self.inline_message_id, reply_markup=reply_markup ) diff --git a/pyrogram/client/types/inline_mode/inline_query.py b/pyrogram/client/types/inline_mode/inline_query.py index eadc539c..c48bb053 100644 --- a/pyrogram/client/types/inline_mode/inline_query.py +++ b/pyrogram/client/types/inline_mode/inline_query.py @@ -88,7 +88,7 @@ class InlineQuery(Object, Update): client=client ) - def answer( + async def answer( self, results: List[InlineQueryResult], cache_time: int = 300, @@ -151,7 +151,7 @@ class InlineQuery(Object, Update): where they wanted to use the bot's inline capabilities. """ - return self._client.answer_inline_query( + return await self._client.answer_inline_query( inline_query_id=self.id, results=results, cache_time=cache_time, diff --git a/pyrogram/client/types/inline_mode/inline_query_result.py b/pyrogram/client/types/inline_mode/inline_query_result.py index c815aedf..6525585b 100644 --- a/pyrogram/client/types/inline_mode/inline_query_result.py +++ b/pyrogram/client/types/inline_mode/inline_query_result.py @@ -67,5 +67,5 @@ class InlineQueryResult(Object): self.input_message_content = input_message_content self.reply_markup = reply_markup - def write(self): + async def write(self): pass diff --git a/pyrogram/client/types/inline_mode/inline_query_result_animation.py b/pyrogram/client/types/inline_mode/inline_query_result_animation.py index 756ee91a..d53bbd59 100644 --- a/pyrogram/client/types/inline_mode/inline_query_result_animation.py +++ b/pyrogram/client/types/inline_mode/inline_query_result_animation.py @@ -91,7 +91,7 @@ class InlineQueryResultAnimation(InlineQueryResult): self.reply_markup = reply_markup self.input_message_content = input_message_content - def write(self): + async def write(self): animation = types.InputWebDocument( url=self.animation_url, size=0, @@ -121,7 +121,7 @@ class InlineQueryResultAnimation(InlineQueryResult): if self.input_message_content else types.InputBotInlineMessageMediaAuto( reply_markup=self.reply_markup.write() if self.reply_markup else None, - **(Parser(None)).parse(self.caption, self.parse_mode) + **await(Parser(None)).parse(self.caption, self.parse_mode) ) ) ) diff --git a/pyrogram/client/types/inline_mode/inline_query_result_article.py b/pyrogram/client/types/inline_mode/inline_query_result_article.py index 900bf477..65e85a0f 100644 --- a/pyrogram/client/types/inline_mode/inline_query_result_article.py +++ b/pyrogram/client/types/inline_mode/inline_query_result_article.py @@ -66,11 +66,11 @@ class InlineQueryResultArticle(InlineQueryResult): self.description = description self.thumb_url = thumb_url - def write(self): + async def write(self): return types.InputBotInlineResult( id=self.id, type=self.type, - send_message=self.input_message_content.write(self.reply_markup), + send_message=await self.input_message_content.write(self.reply_markup), title=self.title, description=self.description, url=self.url, diff --git a/pyrogram/client/types/inline_mode/inline_query_result_photo.py b/pyrogram/client/types/inline_mode/inline_query_result_photo.py index 5905c14e..e3890b9b 100644 --- a/pyrogram/client/types/inline_mode/inline_query_result_photo.py +++ b/pyrogram/client/types/inline_mode/inline_query_result_photo.py @@ -91,7 +91,7 @@ class InlineQueryResultPhoto(InlineQueryResult): self.reply_markup = reply_markup self.input_message_content = input_message_content - def write(self): + async def write(self): photo = types.InputWebDocument( url=self.photo_url, size=0, @@ -117,11 +117,11 @@ class InlineQueryResultPhoto(InlineQueryResult): thumb=thumb, content=photo, send_message=( - self.input_message_content.write(self.reply_markup) + await self.input_message_content.write(self.reply_markup) if self.input_message_content else types.InputBotInlineMessageMediaAuto( reply_markup=self.reply_markup.write() if self.reply_markup else None, - **(Parser(None)).parse(self.caption, self.parse_mode) + **await(Parser(None)).parse(self.caption, self.parse_mode) ) ) ) diff --git a/pyrogram/client/types/input_message_content/input_text_message_content.py b/pyrogram/client/types/input_message_content/input_text_message_content.py index 699ab80d..1247011e 100644 --- a/pyrogram/client/types/input_message_content/input_text_message_content.py +++ b/pyrogram/client/types/input_message_content/input_text_message_content.py @@ -48,9 +48,9 @@ class InputTextMessageContent(InputMessageContent): self.parse_mode = parse_mode self.disable_web_page_preview = disable_web_page_preview - def write(self, reply_markup): + async def write(self, reply_markup): return types.InputBotInlineMessageText( no_webpage=self.disable_web_page_preview or None, reply_markup=reply_markup.write() if reply_markup else None, - **(Parser(None)).parse(self.message_text, self.parse_mode) + **await(Parser(None)).parse(self.message_text, self.parse_mode) ) diff --git a/pyrogram/client/types/messages_and_media/message.py b/pyrogram/client/types/messages_and_media/message.py index 0d1ee2b6..8bf7cac4 100644 --- a/pyrogram/client/types/messages_and_media/message.py +++ b/pyrogram/client/types/messages_and_media/message.py @@ -398,8 +398,8 @@ class Message(Object, Update): self.reply_markup = reply_markup @staticmethod - def _parse(client, message: types.Message or types.MessageService or types.MessageEmpty, users: dict, chats: dict, - is_scheduled: bool = False, replies: int = 1): + async def _parse(client, message: types.Message or types.MessageService or types.MessageEmpty, users: dict, + chats: dict, is_scheduled: bool = False, replies: int = 1): if isinstance(message, types.MessageEmpty): return Message(message_id=message.id, empty=True, client=client) @@ -458,7 +458,7 @@ class Message(Object, Update): if isinstance(action, types.MessageActionPinMessage): try: - parsed_message.pinned_message = client.get_messages( + parsed_message.pinned_message = await client.get_messages( parsed_message.chat.id, reply_to_message_ids=message.id, replies=0 @@ -471,7 +471,7 @@ class Message(Object, Update): if message.reply_to_msg_id and replies: try: - parsed_message.reply_to_message = client.get_messages( + parsed_message.reply_to_message = await client.get_messages( parsed_message.chat.id, reply_to_message_ids=message.id, replies=0 @@ -567,7 +567,7 @@ class Message(Object, Update): video = pyrogram.Video._parse(client, doc, video_attributes, file_name, media.ttl_seconds) elif types.DocumentAttributeSticker in attributes: - sticker = pyrogram.Sticker._parse( + sticker = await pyrogram.Sticker._parse( client, doc, attributes.get(types.DocumentAttributeImageSize, None), attributes[types.DocumentAttributeSticker], @@ -663,7 +663,7 @@ class Message(Object, Update): if message.reply_to_msg_id and replies: try: - parsed_message.reply_to_message = client.get_messages( + parsed_message.reply_to_message = await client.get_messages( parsed_message.chat.id, reply_to_message_ids=message.id, replies=replies - 1 @@ -680,7 +680,7 @@ class Message(Object, Update): else: return "https://t.me/c/{}/{}".format(utils.get_channel_id(self.chat.id), self.message_id) - def reply_text( + async def reply_text( self, text: str, quote: bool = None, @@ -749,7 +749,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_message( + return await self._client.send_message( chat_id=self.chat.id, text=text, parse_mode=parse_mode, @@ -761,7 +761,7 @@ class Message(Object, Update): reply = reply_text - def reply_animation( + async def reply_animation( self, animation: Union[str, BinaryIO], file_ref: str = None, @@ -886,7 +886,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_animation( + return await self._client.send_animation( chat_id=self.chat.id, animation=animation, file_ref=file_ref, @@ -903,7 +903,7 @@ class Message(Object, Update): progress_args=progress_args ) - def reply_audio( + async def reply_audio( self, audio: Union[str, BinaryIO], file_ref: str = None, @@ -1028,7 +1028,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_audio( + return await self._client.send_audio( chat_id=self.chat.id, audio=audio, file_ref=file_ref, @@ -1045,7 +1045,7 @@ class Message(Object, Update): progress_args=progress_args ) - def reply_cached_media( + async def reply_cached_media( self, file_id: str, file_ref: str = None, @@ -1124,7 +1124,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_cached_media( + return await self._client.send_cached_media( chat_id=self.chat.id, file_id=file_id, file_ref=file_ref, @@ -1135,7 +1135,7 @@ class Message(Object, Update): reply_markup=reply_markup ) - def reply_chat_action(self, action: str) -> bool: + async def reply_chat_action(self, action: str) -> bool: """Bound method *reply_chat_action* of :obj:`Message`. Use as a shortcut for: @@ -1168,12 +1168,12 @@ class Message(Object, Update): RPCError: In case of a Telegram RPC error. ValueError: In case the provided string is not a valid chat action. """ - return self._client.send_chat_action( + return await self._client.send_chat_action( chat_id=self.chat.id, action=action ) - def reply_contact( + async def reply_contact( self, phone_number: str, first_name: str, @@ -1247,7 +1247,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_contact( + return await self._client.send_contact( chat_id=self.chat.id, phone_number=phone_number, first_name=first_name, @@ -1258,7 +1258,7 @@ class Message(Object, Update): reply_markup=reply_markup ) - def reply_document( + async def reply_document( self, document: Union[str, BinaryIO], file_ref: str = None, @@ -1371,7 +1371,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_document( + return await self._client.send_document( chat_id=self.chat.id, document=document, file_ref=file_ref, @@ -1385,7 +1385,7 @@ class Message(Object, Update): progress_args=progress_args ) - def reply_game( + async def reply_game( self, game_short_name: str, quote: bool = None, @@ -1446,7 +1446,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_game( + return await self._client.send_game( chat_id=self.chat.id, game_short_name=game_short_name, disable_notification=disable_notification, @@ -1454,7 +1454,7 @@ class Message(Object, Update): reply_markup=reply_markup ) - def reply_inline_bot_result( + async def reply_inline_bot_result( self, query_id: int, result_id: str, @@ -1514,7 +1514,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_inline_bot_result( + return await self._client.send_inline_bot_result( chat_id=self.chat.id, query_id=query_id, result_id=result_id, @@ -1523,7 +1523,7 @@ class Message(Object, Update): hide_via=hide_via ) - def reply_location( + async def reply_location( self, latitude: float, longitude: float, @@ -1589,7 +1589,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_location( + return await self._client.send_location( chat_id=self.chat.id, latitude=latitude, longitude=longitude, @@ -1598,7 +1598,7 @@ class Message(Object, Update): reply_markup=reply_markup ) - def reply_media_group( + async def reply_media_group( self, media: List[Union["pyrogram.InputMediaPhoto", "pyrogram.InputMediaVideo"]], quote: bool = None, @@ -1652,14 +1652,14 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_media_group( + return await self._client.send_media_group( chat_id=self.chat.id, media=media, disable_notification=disable_notification, reply_to_message_id=reply_to_message_id ) - def reply_photo( + async def reply_photo( self, photo: Union[str, BinaryIO], file_ref: str = None, @@ -1771,7 +1771,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_photo( + return await self._client.send_photo( chat_id=self.chat.id, photo=photo, file_ref=file_ref, @@ -1785,7 +1785,7 @@ class Message(Object, Update): progress_args=progress_args ) - def reply_poll( + async def reply_poll( self, question: str, options: List[str], @@ -1851,7 +1851,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_poll( + return await self._client.send_poll( chat_id=self.chat.id, question=question, options=options, @@ -1860,7 +1860,7 @@ class Message(Object, Update): reply_markup=reply_markup ) - def reply_sticker( + async def reply_sticker( self, sticker: Union[str, BinaryIO], file_ref: str = None, @@ -1954,7 +1954,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_sticker( + return await self._client.send_sticker( chat_id=self.chat.id, sticker=sticker, file_ref=file_ref, @@ -1965,7 +1965,7 @@ class Message(Object, Update): progress_args=progress_args ) - def reply_venue( + async def reply_venue( self, latitude: float, longitude: float, @@ -2050,7 +2050,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_venue( + return await self._client.send_venue( chat_id=self.chat.id, latitude=latitude, longitude=longitude, @@ -2063,7 +2063,7 @@ class Message(Object, Update): reply_markup=reply_markup ) - def reply_video( + async def reply_video( self, video: Union[str, BinaryIO], file_ref: str = None, @@ -2192,7 +2192,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_video( + return await self._client.send_video( chat_id=self.chat.id, video=video, file_ref=file_ref, @@ -2210,7 +2210,7 @@ class Message(Object, Update): progress_args=progress_args ) - def reply_video_note( + async def reply_video_note( self, video_note: Union[str, BinaryIO], file_ref: str = None, @@ -2319,7 +2319,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_video_note( + return await self._client.send_video_note( chat_id=self.chat.id, video_note=video_note, file_ref=file_ref, @@ -2333,7 +2333,7 @@ class Message(Object, Update): progress_args=progress_args ) - def reply_voice( + async def reply_voice( self, voice: Union[str, BinaryIO], file_ref: str = None, @@ -2443,7 +2443,7 @@ class Message(Object, Update): if reply_to_message_id is None and quote: reply_to_message_id = self.message_id - return self._client.send_voice( + return await self._client.send_voice( chat_id=self.chat.id, voice=voice, file_ref=file_ref, @@ -2457,7 +2457,7 @@ class Message(Object, Update): progress_args=progress_args ) - def edit_text( + async def edit_text( self, text: str, parse_mode: Union[str, None] = object, @@ -2504,7 +2504,7 @@ class Message(Object, Update): Raises: RPCError: In case of a Telegram RPC error. """ - return self._client.edit_message_text( + return await self._client.edit_message_text( chat_id=self.chat.id, message_id=self.message_id, text=text, @@ -2515,7 +2515,7 @@ class Message(Object, Update): edit = edit_text - def edit_caption( + async def edit_caption( self, caption: str, parse_mode: Union[str, None] = object, @@ -2558,7 +2558,7 @@ class Message(Object, Update): Raises: RPCError: In case of a Telegram RPC error. """ - return self._client.edit_message_caption( + return await self._client.edit_message_caption( chat_id=self.chat.id, message_id=self.message_id, caption=caption, @@ -2566,7 +2566,7 @@ class Message(Object, Update): reply_markup=reply_markup ) - def edit_media(self, media: InputMedia, reply_markup: "pyrogram.InlineKeyboardMarkup" = None) -> "Message": + async def edit_media(self, media: InputMedia, reply_markup: "pyrogram.InlineKeyboardMarkup" = None) -> "Message": """Bound method *edit_media* of :obj:`Message`. Use as a shortcut for: @@ -2597,14 +2597,14 @@ class Message(Object, Update): Raises: RPCError: In case of a Telegram RPC error. """ - return self._client.edit_message_media( + return await self._client.edit_message_media( chat_id=self.chat.id, message_id=self.message_id, media=media, reply_markup=reply_markup ) - def edit_reply_markup(self, reply_markup: "pyrogram.InlineKeyboardMarkup" = None) -> "Message": + async def edit_reply_markup(self, reply_markup: "pyrogram.InlineKeyboardMarkup" = None) -> "Message": """Bound method *edit_reply_markup* of :obj:`Message`. Use as a shortcut for: @@ -2633,13 +2633,13 @@ class Message(Object, Update): Raises: RPCError: In case of a Telegram RPC error. """ - return self._client.edit_message_reply_markup( + return await self._client.edit_message_reply_markup( chat_id=self.chat.id, message_id=self.message_id, reply_markup=reply_markup ) - def forward( + async def forward( self, chat_id: int or str, disable_notification: bool = None, @@ -2700,7 +2700,7 @@ class Message(Object, Update): raise ValueError("Users cannot send messages with Game media type") if self.text: - return self._client.send_message( + return await self._client.send_message( chat_id, text=self.text.html, parse_mode="html", @@ -2743,7 +2743,7 @@ class Message(Object, Update): file_id = self.video_note.file_id file_ref = self.video_note.file_ref elif self.contact: - return self._client.send_contact( + return await self._client.send_contact( chat_id, phone_number=self.contact.phone_number, first_name=self.contact.first_name, @@ -2753,7 +2753,7 @@ class Message(Object, Update): schedule_date=schedule_date ) elif self.location: - return self._client.send_location( + return await self._client.send_location( chat_id, latitude=self.location.latitude, longitude=self.location.longitude, @@ -2761,7 +2761,7 @@ class Message(Object, Update): schedule_date=schedule_date ) elif self.venue: - return self._client.send_venue( + return await self._client.send_venue( chat_id, latitude=self.venue.location.latitude, longitude=self.venue.location.longitude, @@ -2773,7 +2773,7 @@ class Message(Object, Update): schedule_date=schedule_date ) elif self.poll: - return self._client.send_poll( + return await self._client.send_poll( chat_id, question=self.poll.question, options=[opt.text for opt in self.poll.options], @@ -2781,7 +2781,7 @@ class Message(Object, Update): schedule_date=schedule_date ) elif self.game: - return self._client.send_game( + return await self._client.send_game( chat_id, game_short_name=self.game.short_name, disable_notification=disable_notification @@ -2790,13 +2790,13 @@ class Message(Object, Update): raise ValueError("Unknown media type") if self.sticker or self.video_note: # Sticker and VideoNote should have no caption - return send_media(file_id=file_id, file_ref=file_ref) + return await send_media(file_id=file_id, file_ref=file_ref) else: - return send_media(file_id=file_id, file_ref=file_ref, caption=caption, parse_mode="html") + return await send_media(file_id=file_id, file_ref=file_ref, caption=caption, parse_mode="html") else: raise ValueError("Can't copy this message") else: - return self._client.forward_messages( + return await self._client.forward_messages( chat_id=chat_id, from_chat_id=self.chat.id, message_ids=self.message_id, @@ -2804,7 +2804,7 @@ class Message(Object, Update): schedule_date=schedule_date ) - def delete(self, revoke: bool = True): + async def delete(self, revoke: bool = True): """Bound method *delete* of :obj:`Message`. Use as a shortcut for: @@ -2834,13 +2834,13 @@ class Message(Object, Update): Raises: RPCError: In case of a Telegram RPC error. """ - return self._client.delete_messages( + return await self._client.delete_messages( chat_id=self.chat.id, message_ids=self.message_id, revoke=revoke ) - def click(self, x: int or str = 0, y: int = None, quote: bool = None, timeout: int = 10): + async def click(self, x: int or str = 0, y: int = None, quote: bool = None, timeout: int = 10): """Bound method *click* of :obj:`Message`. Use as a shortcut for clicking a button attached to the message instead of: @@ -2944,7 +2944,7 @@ class Message(Object, Update): if is_inline: if button.callback_data: - return self._client.request_callback_answer( + return await self._client.request_callback_answer( chat_id=self.chat.id, message_id=self.message_id, callback_data=button.callback_data, @@ -2959,9 +2959,9 @@ class Message(Object, Update): else: raise ValueError("This button is not supported yet") else: - self.reply(button, quote=quote) + await self.reply(button, quote=quote) - def retract_vote( + async def retract_vote( self, ) -> "pyrogram.Poll": """Bound method *retract_vote* of :obj:`Message`. @@ -2987,12 +2987,12 @@ class Message(Object, Update): RPCError: In case of a Telegram RPC error. """ - return self._client.retract_vote( + return await self._client.retract_vote( chat_id=self.chat.id, message_id=self.message_id ) - def download( + async def download( self, file_name: str = "", block: bool = True, @@ -3052,7 +3052,7 @@ class Message(Object, Update): RPCError: In case of a Telegram RPC error. ``ValueError``: If the message doesn't contain any downloadable media """ - return self._client.download_media( + return await self._client.download_media( message=self, file_name=file_name, block=block, @@ -3060,7 +3060,7 @@ class Message(Object, Update): progress_args=progress_args, ) - def vote( + async def vote( self, option: int, ) -> "pyrogram.Poll": @@ -3092,13 +3092,13 @@ class Message(Object, Update): RPCError: In case of a Telegram RPC error. """ - return self._client.vote_poll( + return await self._client.vote_poll( chat_id=self.chat.id, message_id=self.message_id, option=option ) - def pin(self, disable_notification: bool = None) -> "Message": + async def pin(self, disable_notification: bool = None) -> "Message": """Bound method *pin* of :obj:`Message`. Use as a shortcut for: @@ -3126,7 +3126,7 @@ class Message(Object, Update): Raises: RPCError: In case of a Telegram RPC error. """ - return self._client.pin_chat_message( + return await self._client.pin_chat_message( chat_id=self.chat.id, message_id=self.message_id, disable_notification=disable_notification diff --git a/pyrogram/client/types/messages_and_media/sticker.py b/pyrogram/client/types/messages_and_media/sticker.py index b434e451..aca7d3a3 100644 --- a/pyrogram/client/types/messages_and_media/sticker.py +++ b/pyrogram/client/types/messages_and_media/sticker.py @@ -16,10 +16,11 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from functools import lru_cache from struct import pack from typing import List +from async_lru import alru_cache + import pyrogram from pyrogram.api import types, functions from pyrogram.errors import StickersetInvalid @@ -105,28 +106,28 @@ class Sticker(Object): # self.mask_position = mask_position @staticmethod - @lru_cache(maxsize=256) - def _get_sticker_set_name(send, input_sticker_set_id): + @alru_cache(maxsize=256) + async def _get_sticker_set_name(send, input_sticker_set_id): try: - return send( + return (await send( functions.messages.GetStickerSet( stickerset=types.InputStickerSetID( id=input_sticker_set_id[0], access_hash=input_sticker_set_id[1] ) ) - ).set.short_name + )).set.short_name except StickersetInvalid: return None @staticmethod - def _parse(client, sticker: types.Document, image_size_attributes: types.DocumentAttributeImageSize, - sticker_attributes: types.DocumentAttributeSticker, file_name: str) -> "Sticker": + async def _parse(client, sticker: types.Document, image_size_attributes: types.DocumentAttributeImageSize, + sticker_attributes: types.DocumentAttributeSticker, file_name: str) -> "Sticker": sticker_set = sticker_attributes.stickerset if isinstance(sticker_set, types.InputStickerSetID): input_sticker_set_id = (sticker_set.id, sticker_set.access_hash) - set_name = Sticker._get_sticker_set_name(client.send, input_sticker_set_id) + set_name = await Sticker._get_sticker_set_name(client.send, input_sticker_set_id) else: set_name = None diff --git a/pyrogram/client/types/update.py b/pyrogram/client/types/update.py index 1e70944a..3f70f580 100644 --- a/pyrogram/client/types/update.py +++ b/pyrogram/client/types/update.py @@ -16,11 +16,11 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -class StopPropagation(StopIteration): +class StopPropagation(StopAsyncIteration): pass -class ContinuePropagation(StopIteration): +class ContinuePropagation(StopAsyncIteration): pass diff --git a/pyrogram/client/types/user_and_chats/chat.py b/pyrogram/client/types/user_and_chats/chat.py index fe103573..85ef0a1f 100644 --- a/pyrogram/client/types/user_and_chats/chat.py +++ b/pyrogram/client/types/user_and_chats/chat.py @@ -233,13 +233,13 @@ class Chat(Object): return Chat._parse_channel_chat(client, chats[peer.channel_id]) @staticmethod - def _parse_full(client, chat_full: types.messages.ChatFull or types.UserFull) -> "Chat": + async def _parse_full(client, chat_full: types.messages.ChatFull or types.UserFull) -> "Chat": if isinstance(chat_full, types.UserFull): parsed_chat = Chat._parse_user_chat(client, chat_full.user) parsed_chat.description = chat_full.about if chat_full.pinned_msg_id: - parsed_chat.pinned_message = client.get_messages( + parsed_chat.pinned_message = await client.get_messages( parsed_chat.id, message_ids=chat_full.pinned_msg_id ) @@ -273,7 +273,7 @@ class Chat(Object): parsed_chat.linked_chat = Chat._parse_channel_chat(client, linked_chat) if full_chat.pinned_msg_id: - parsed_chat.pinned_message = client.get_messages( + parsed_chat.pinned_message = await client.get_messages( parsed_chat.id, message_ids=full_chat.pinned_msg_id ) @@ -292,7 +292,7 @@ class Chat(Object): else: return Chat._parse_channel_chat(client, chat) - def archive(self): + async def archive(self): """Bound method *archive* of :obj:`Chat`. Use as a shortcut for: @@ -313,9 +313,9 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.archive_chats(self.id) + return await self._client.archive_chats(self.id) - def unarchive(self): + async def unarchive(self): """Bound method *unarchive* of :obj:`Chat`. Use as a shortcut for: @@ -336,10 +336,10 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.unarchive_chats(self.id) + return await self._client.unarchive_chats(self.id) # TODO: Remove notes about "All Members Are Admins" for basic groups, the attribute doesn't exist anymore - def set_title(self, title: str) -> bool: + async def set_title(self, title: str) -> bool: """Bound method *set_title* of :obj:`Chat`. Use as a shortcut for: @@ -372,12 +372,12 @@ class Chat(Object): ValueError: In case a chat_id belongs to user. """ - return self._client.set_chat_title( + return await self._client.set_chat_title( chat_id=self.id, title=title ) - def set_description(self, description: str) -> bool: + async def set_description(self, description: str) -> bool: """Bound method *set_description* of :obj:`Chat`. Use as a shortcut for: @@ -406,12 +406,12 @@ class Chat(Object): ValueError: If a chat_id doesn't belong to a supergroup or a channel. """ - return self._client.set_chat_description( + return await self._client.set_chat_description( chat_id=self.id, description=description ) - def set_photo(self, photo: str) -> bool: + async def set_photo(self, photo: str) -> bool: """Bound method *set_photo* of :obj:`Chat`. Use as a shortcut for: @@ -440,12 +440,12 @@ class Chat(Object): ValueError: if a chat_id belongs to user. """ - return self._client.set_chat_photo( + return await self._client.set_chat_photo( chat_id=self.id, photo=photo ) - def kick_member( + async def kick_member( self, user_id: Union[int, str], until_date: int = 0 @@ -489,13 +489,13 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.kick_chat_member( + return await self._client.kick_chat_member( chat_id=self.id, user_id=user_id, until_date=until_date ) - def unban_member( + async def unban_member( self, user_id: Union[int, str] ) -> bool: @@ -527,12 +527,12 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.unban_chat_member( + return await self._client.unban_chat_member( chat_id=self.id, user_id=user_id, ) - def restrict_member( + async def restrict_member( self, user_id: Union[int, str], permissions: ChatPermissions, @@ -575,14 +575,14 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.restrict_chat_member( + return await self._client.restrict_chat_member( chat_id=self.id, user_id=user_id, permissions=permissions, until_date=until_date, ) - def promote_member( + async def promote_member( self, user_id: Union[int, str], can_change_info: bool = True, @@ -649,7 +649,7 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.promote_chat_member( + return await self._client.promote_chat_member( chat_id=self.id, user_id=user_id, can_change_info=can_change_info, @@ -662,7 +662,7 @@ class Chat(Object): can_promote_members=can_promote_members ) - def join(self): + async def join(self): """Bound method *join* of :obj:`Chat`. Use as a shortcut for: @@ -686,9 +686,9 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.join_chat(self.username or self.id) + return await self._client.join_chat(self.username or self.id) - def leave(self): + async def leave(self): """Bound method *leave* of :obj:`Chat`. Use as a shortcut for: @@ -706,9 +706,9 @@ class Chat(Object): RPCError: In case of a Telegram RPC error. """ - return self._client.leave_chat(self.id) + return await self._client.leave_chat(self.id) - def export_invite_link(self): + async def export_invite_link(self): """Bound method *export_invite_link* of :obj:`Chat`. Use as a shortcut for: @@ -729,9 +729,9 @@ class Chat(Object): ValueError: In case the chat_id belongs to a user. """ - return self._client.export_chat_invite_link(self.id) + return await self._client.export_chat_invite_link(self.id) - def get_member( + async def get_member( self, user_id: Union[int, str], ) -> "pyrogram.ChatMember": @@ -755,12 +755,12 @@ class Chat(Object): :obj:`ChatMember`: On success, a chat member is returned. """ - return self._client.get_chat_member( + return await self._client.get_chat_member( self.id, user_id=user_id ) - def get_members( + async def get_members( self, offset: int = 0, limit: int = 200, @@ -785,7 +785,7 @@ class Chat(Object): List of :obj:`ChatMember`: On success, a list of chat members is returned. """ - return self._client.get_chat_members( + return await self._client.get_chat_members( self.id, offset=offset, limit=limit, @@ -825,7 +825,7 @@ class Chat(Object): filter=filter ) - def add_members( + async def add_members( self, user_ids: Union[Union[int, str], List[Union[int, str]]], forward_limit: int = 100 @@ -847,7 +847,7 @@ class Chat(Object): ``bool``: On success, True is returned. """ - return self._client.add_chat_members( + return await self._client.add_chat_members( self.id, user_ids=user_ids, forward_limit=forward_limit diff --git a/pyrogram/client/types/user_and_chats/user.py b/pyrogram/client/types/user_and_chats/user.py index b90db42a..f386e161 100644 --- a/pyrogram/client/types/user_and_chats/user.py +++ b/pyrogram/client/types/user_and_chats/user.py @@ -231,7 +231,7 @@ class User(Object, Update): client=client ) - def archive(self): + async def archive(self): """Bound method *archive* of :obj:`User`. Use as a shortcut for: @@ -252,9 +252,9 @@ class User(Object, Update): RPCError: In case of a Telegram RPC error. """ - return self._client.archive_chats(self.id) + return await self._client.archive_chats(self.id) - def unarchive(self): + async def unarchive(self): """Bound method *unarchive* of :obj:`User`. Use as a shortcut for: @@ -275,7 +275,7 @@ class User(Object, Update): RPCError: In case of a Telegram RPC error. """ - return self._client.unarchive_chats(self.id) + return await self._client.unarchive_chats(self.id) def block(self): """Bound method *block* of :obj:`User`. diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 50fb3b7f..3a2126e9 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -16,9 +16,8 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging -import threading -import time from .transport import * from ..session.internals import DataCenter @@ -45,20 +44,19 @@ class Connection: self.address = DataCenter(dc_id, test_mode, ipv6) self.mode = self.MODES.get(mode, TCPAbridged) - self.lock = threading.Lock() - self.connection = None + self.protocol = None # type: TCP - def connect(self): + async def connect(self): for i in range(Connection.MAX_RETRIES): - self.connection = self.mode(self.ipv6, self.proxy) + self.protocol = self.mode(self.ipv6, self.proxy) try: log.info("Connecting...") - self.connection.connect(self.address) + await self.protocol.connect(self.address) except OSError as e: log.warning(e) # TODO: Remove - self.connection.close() - time.sleep(1) + self.protocol.close() + await asyncio.sleep(1) else: log.info("Connected! {} DC{} - IPv{} - {}".format( "Test" if self.test_mode else "Production", @@ -72,12 +70,14 @@ class Connection: raise TimeoutError def close(self): - self.connection.close() + self.protocol.close() log.info("Disconnected") - def send(self, data: bytes): - with self.lock: - self.connection.sendall(data) + async def send(self, data: bytes): + try: + await self.protocol.send(data) + except Exception: + raise OSError - def recv(self) -> bytes or None: - return self.connection.recvall() + async def recv(self) -> bytes or None: + return await self.protocol.recv() diff --git a/pyrogram/connection/transport/tcp/__init__.py b/pyrogram/connection/transport/tcp/__init__.py index 9628d99e..1909e723 100644 --- a/pyrogram/connection/transport/tcp/__init__.py +++ b/pyrogram/connection/transport/tcp/__init__.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +from .tcp import TCP from .tcp_abridged import TCPAbridged from .tcp_abridged_o import TCPAbridgedO from .tcp_full import TCPFull diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index db1c3ee7..070907f4 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import ipaddress import logging import socket @@ -34,8 +35,17 @@ except ImportError as e: log = logging.getLogger(__name__) -class TCP(socks.socksocket): +class TCP: + TIMEOUT = 10 + def __init__(self, ipv6: bool, proxy: dict): + self.socket = None + + self.reader = None # type: asyncio.StreamReader + self.writer = None # type: asyncio.StreamWriter + + self.lock = asyncio.Lock() + if proxy.get("enabled", False): hostname = proxy.get("hostname", None) port = proxy.get("port", None) @@ -43,14 +53,14 @@ class TCP(socks.socksocket): try: ip_address = ipaddress.ip_address(hostname) except ValueError: - super().__init__(socket.AF_INET) + self.socket = socks.socksocket(socket.AF_INET) else: if isinstance(ip_address, ipaddress.IPv6Address): - super().__init__(socket.AF_INET6) + self.socket = socks.socksocket(socket.AF_INET6) else: - super().__init__(socket.AF_INET) + self.socket = socks.socksocket(socket.AF_INET) - self.set_proxy( + self.socket.set_proxy( proxy_type=socks.SOCKS5, addr=hostname, port=port, @@ -60,35 +70,50 @@ class TCP(socks.socksocket): log.info("Using proxy {}:{}".format(hostname, port)) else: - super().__init__( + self.socket = socks.socksocket( socket.AF_INET6 if ipv6 else socket.AF_INET ) - self.settimeout(10) + self.socket.settimeout(TCP.TIMEOUT) + + async def connect(self, address: tuple): + self.socket.connect(address) + self.reader, self.writer = await asyncio.open_connection(sock=self.socket) def close(self): try: - self.shutdown(socket.SHUT_RDWR) - except OSError: - pass - finally: - # A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout. - # This is a workaround that seems to fix the occasional delayed stop of a client. - time.sleep(0.001) - super().close() + self.writer.close() + except AttributeError: + try: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError: + pass + finally: + # A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout. + # This is a workaround that seems to fix the occasional delayed stop of a client. + time.sleep(0.001) + self.socket.close() - def recvall(self, length: int) -> bytes or None: + async def send(self, data: bytes): + async with self.lock: + self.writer.write(data) + await self.writer.drain() + + async def recv(self, length: int = 0): data = b"" while len(data) < length: try: - packet = super().recv(length - len(data)) - except OSError: + chunk = await asyncio.wait_for( + self.reader.read(length - len(data)), + TCP.TIMEOUT + ) + except (OSError, asyncio.TimeoutError): return None else: - if packet: - data += packet + if chunk: + data += chunk else: return None diff --git a/pyrogram/connection/transport/tcp/tcp_abridged.py b/pyrogram/connection/transport/tcp/tcp_abridged.py index e828aa4c..4b4de7b2 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged.py @@ -27,30 +27,30 @@ class TCPAbridged(TCP): def __init__(self, ipv6: bool, proxy: dict): super().__init__(ipv6, proxy) - def connect(self, address: tuple): - super().connect(address) - super().sendall(b"\xef") + async def connect(self, address: tuple): + await super().connect(address) + await super().send(b"\xef") - def sendall(self, data: bytes, *args): + async def send(self, data: bytes, *args): length = len(data) // 4 - super().sendall( + await super().send( (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data ) - def recvall(self, length: int = 0) -> bytes or None: - length = super().recvall(1) + async def recv(self, length: int = 0) -> bytes or None: + length = await super().recv(1) if length is None: return None if length == b"\x7f": - length = super().recvall(3) + length = await super().recv(3) if length is None: return None - return super().recvall(int.from_bytes(length, "little") * 4) + return await super().recv(int.from_bytes(length, "little") * 4) diff --git a/pyrogram/connection/transport/tcp/tcp_abridged_o.py b/pyrogram/connection/transport/tcp/tcp_abridged_o.py index 8417735c..e8f8fba0 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged_o.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged_o.py @@ -34,8 +34,8 @@ class TCPAbridgedO(TCP): self.encrypt = None self.decrypt = None - def connect(self, address: tuple): - super().connect(address) + async def connect(self, address: tuple): + await super().connect(address) while True: nonce = bytearray(os.urandom(64)) @@ -51,12 +51,12 @@ class TCPAbridgedO(TCP): nonce[56:64] = AES.ctr256_encrypt(nonce, *self.encrypt)[56:64] - super().sendall(nonce) + await super().send(nonce) - def sendall(self, data: bytes, *args): + async def send(self, data: bytes, *args): length = len(data) // 4 - super().sendall( + await super().send( AES.ctr256_encrypt( (bytes([length]) if length <= 126 @@ -66,8 +66,8 @@ class TCPAbridgedO(TCP): ) ) - def recvall(self, length: int = 0) -> bytes or None: - length = super().recvall(1) + async def recv(self, length: int = 0) -> bytes or None: + length = await super().recv(1) if length is None: return None @@ -75,14 +75,14 @@ class TCPAbridgedO(TCP): length = AES.ctr256_decrypt(length, *self.decrypt) if length == b"\x7f": - length = super().recvall(3) + length = await super().recv(3) if length is None: return None length = AES.ctr256_decrypt(length, *self.decrypt) - data = super().recvall(int.from_bytes(length, "little") * 4) + data = await super().recv(int.from_bytes(length, "little") * 4) if data is None: return None diff --git a/pyrogram/connection/transport/tcp/tcp_full.py b/pyrogram/connection/transport/tcp/tcp_full.py index 366c1e65..0490fa65 100644 --- a/pyrogram/connection/transport/tcp/tcp_full.py +++ b/pyrogram/connection/transport/tcp/tcp_full.py @@ -31,34 +31,33 @@ class TCPFull(TCP): self.seq_no = None - def connect(self, address: tuple): - super().connect(address) + async def connect(self, address: tuple): + await super().connect(address) self.seq_no = 0 - def sendall(self, data: bytes, *args): - # 12 = packet_length (4), seq_no (4), crc32 (4) (at the end) + async def send(self, data: bytes, *args): data = pack(" bytes or None: - length = super().recvall(4) + async def recv(self, length: int = 0) -> bytes or None: + length = await super().recv(4) if length is None: return None - packet = super().recvall(unpack(" bytes or None: - length = super().recvall(4) + async def recv(self, length: int = 0) -> bytes or None: + length = await super().recv(4) if length is None: return None - return super().recvall(unpack(" bytes or None: - length = super().recvall(4) + async def recv(self, length: int = 0) -> bytes or None: + length = await super().recv(4) if length is None: return None length = AES.ctr256_decrypt(length, *self.decrypt) - data = super().recvall(unpack(" +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Pyrogram. If not, see . + +from hashlib import sha256 +from io import BytesIO +from os import urandom + +from pyrogram.api.core import Message, Long +from . import AES, KDF + + +class MTProto: + @staticmethod + def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> bytes: + data = Long(salt) + session_id + message.write() + padding = urandom(-(len(data) + 12) % 16 + 12) + + # 88 = 88 + 0 (outgoing message) + msg_key_large = sha256(auth_key[88: 88 + 32] + data + padding).digest() + msg_key = msg_key_large[8:24] + aes_key, aes_iv = KDF(auth_key, msg_key, True) + + return auth_key_id + msg_key + AES.ige256_encrypt(data + padding, aes_key, aes_iv) + + @staticmethod + def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message: + assert b.read(8) == auth_key_id, b.getvalue() + + msg_key = b.read(16) + aes_key, aes_iv = KDF(auth_key, msg_key, False) + data = BytesIO(AES.ige256_decrypt(b.read(), aes_key, aes_iv)) + data.read(8) + + # https://core.telegram.org/mtproto/security_guidelines#checking-session-id + assert data.read(8) == session_id + + message = Message.read(data) + + # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key + # https://core.telegram.org/mtproto/security_guidelines#checking-message-length + # 96 = 88 + 8 (incoming message) + assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] + + # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id + assert message.msg_id % 2 != 0 + + return message diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index c3d0b99c..6795928a 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging import time from hashlib import sha1 @@ -57,14 +58,14 @@ class Auth: b.seek(20) # Skip auth_key_id (8), message_id (8) and message_length (4) return TLObject.read(b) - def send(self, data: TLObject): + async def send(self, data: TLObject): data = self.pack(data) - self.connection.send(data) - response = BytesIO(self.connection.recv()) + await self.connection.send(data) + response = BytesIO(await self.connection.recv()) return self.unpack(response) - def create(self): + async def create(self): """ https://core.telegram.org/mtproto/auth_key https://core.telegram.org/mtproto/samples-auth_key @@ -79,12 +80,12 @@ class Auth: try: log.info("Start creating a new auth key on DC{}".format(self.dc_id)) - self.connection.connect() + await self.connection.connect() # Step 1; Step 2 nonce = int.from_bytes(urandom(16), "little", signed=True) log.debug("Send req_pq: {}".format(nonce)) - res_pq = self.send(functions.ReqPqMulti(nonce=nonce)) + res_pq = await self.send(functions.ReqPqMulti(nonce=nonce)) log.debug("Got ResPq: {}".format(res_pq.server_nonce)) log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints)) @@ -128,7 +129,7 @@ class Auth: # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok log.debug("Send req_DH_params") - server_dh_params = self.send( + server_dh_params = await self.send( functions.ReqDHParams( nonce=nonce, server_nonce=server_nonce, @@ -188,7 +189,7 @@ class Auth: encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv) log.debug("Send set_client_DH_params") - set_client_dh_params_answer = self.send( + set_client_dh_params_answer = await self.send( functions.SetClientDHParams( nonce=nonce, server_nonce=server_nonce, @@ -255,7 +256,7 @@ class Auth: else: raise e - time.sleep(1) + await asyncio.sleep(1) continue else: return auth_key diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 3387b3d2..beb4affa 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -16,23 +16,19 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging -import threading -import time -from datetime import timedelta, datetime -from hashlib import sha1, sha256 +from datetime import datetime, timedelta +from hashlib import sha1 from io import BytesIO -from os import urandom -from queue import Queue -from threading import Event, Thread import pyrogram from pyrogram import __copyright__, __license__, __version__ -from pyrogram.api import functions, types, core +from pyrogram.api import functions, types from pyrogram.api.all import layer -from pyrogram.api.core import Message, TLObject, MsgContainer, Long, FutureSalt, Int +from pyrogram.api.core import TLObject, MsgContainer, Int, Long, FutureSalt, FutureSalts from pyrogram.connection import Connection -from pyrogram.crypto import AES, KDF +from pyrogram.crypto import MTProto from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated, FloodWait from .internals import MsgId, MsgFactory @@ -42,7 +38,7 @@ log = logging.getLogger(__name__) class Result: def __init__(self): self.value = None - self.event = Event() + self.event = asyncio.Event() class Session: @@ -101,20 +97,21 @@ class Session: self.pending_acks = set() - self.recv_queue = Queue() + self.recv_queue = asyncio.Queue() self.results = {} - self.ping_thread = None - self.ping_thread_event = Event() + self.ping_task = None + self.ping_task_event = asyncio.Event() - self.next_salt_thread = None - self.next_salt_thread_event = Event() + self.next_salt_task = None + self.next_salt_task_event = asyncio.Event() - self.net_worker_list = [] + self.net_worker_task = None + self.recv_task = None - self.is_connected = Event() + self.is_connected = asyncio.Event() - def start(self): + async def start(self): while True: self.connection = Connection( self.dc_id, @@ -124,35 +121,26 @@ class Session: ) try: - self.connection.connect() + await self.connection.connect() - for i in range(self.NET_WORKERS): - self.net_worker_list.append( - Thread( - target=self.net_worker, - name="NetWorker#{}".format(i + 1) - ) - ) + self.net_worker_task = asyncio.ensure_future(self.net_worker()) + self.recv_task = asyncio.ensure_future(self.recv()) - self.net_worker_list[-1].start() - - Thread(target=self.recv, name="RecvThread").start() - - self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT) + self.current_salt = FutureSalt(0, 0, Session.INITIAL_SALT) self.current_salt = FutureSalt( 0, 0, - self._send( + (await self._send( functions.Ping(ping_id=0), timeout=self.START_TIMEOUT - ).new_server_salt + )).new_server_salt ) - self.current_salt = self._send(functions.GetFutureSalts(num=1), timeout=self.START_TIMEOUT).salts[0] + self.current_salt = \ + (await self._send(functions.GetFutureSalts(num=1), timeout=self.START_TIMEOUT)).salts[0] - self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread") - self.next_salt_thread.start() + self.next_salt_task = asyncio.ensure_future(self.next_salt()) if not self.is_cdn: - self._send( + await self._send( functions.InvokeWithLayer( layer=layer, query=functions.InitConnection( @@ -169,116 +157,81 @@ class Session: timeout=self.START_TIMEOUT ) - self.ping_thread = Thread(target=self.ping, name="PingThread") - self.ping_thread.start() + self.ping_task = asyncio.ensure_future(self.ping()) log.info("Session initialized: Layer {}".format(layer)) log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version)) log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper())) except AuthKeyDuplicated as e: - self.stop() + await self.stop() raise e except (OSError, TimeoutError, RPCError): - self.stop() + await self.stop() except Exception as e: - self.stop() + await self.stop() raise e else: break self.is_connected.set() - log.debug("Session started") + log.info("Session started") - def stop(self): + async def stop(self): self.is_connected.clear() - self.ping_thread_event.set() - self.next_salt_thread_event.set() + self.ping_task_event.set() + self.next_salt_task_event.set() - if self.ping_thread is not None: - self.ping_thread.join() + if self.ping_task is not None: + await self.ping_task - if self.next_salt_thread is not None: - self.next_salt_thread.join() + if self.next_salt_task is not None: + await self.next_salt_task - self.ping_thread_event.clear() - self.next_salt_thread_event.clear() + self.ping_task_event.clear() + self.next_salt_task_event.clear() self.connection.close() - for i in range(self.NET_WORKERS): - self.recv_queue.put(None) + if self.recv_task: + await self.recv_task - for i in self.net_worker_list: - i.join() - - self.net_worker_list.clear() - self.recv_queue.queue.clear() + if self.net_worker_task: + await self.net_worker_task for i in self.results.values(): i.event.set() if not self.is_media and callable(self.client.disconnect_handler): try: - self.client.disconnect_handler(self.client) + await self.client.disconnect_handler(self.client) except Exception as e: log.error(e, exc_info=True) - log.debug("Session stopped") + log.info("Session stopped") - def restart(self): - self.stop() - self.start() + async def restart(self): + await self.stop() + await self.start() - def pack(self, message: Message): - data = Long(self.current_salt.salt) + self.session_id + message.write() - padding = urandom(-(len(data) + 12) % 16 + 12) - - # 88 = 88 + 0 (outgoing message) - msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest() - msg_key = msg_key_large[8:24] - aes_key, aes_iv = KDF(self.auth_key, msg_key, True) - - return self.auth_key_id + msg_key + AES.ige256_encrypt(data + padding, aes_key, aes_iv) - - def unpack(self, b: BytesIO) -> Message: - assert b.read(8) == self.auth_key_id, b.getvalue() - - msg_key = b.read(16) - aes_key, aes_iv = KDF(self.auth_key, msg_key, False) - data = BytesIO(AES.ige256_decrypt(b.read(), aes_key, aes_iv)) - data.read(8) - - # https://core.telegram.org/mtproto/security_guidelines#checking-session-id - assert data.read(8) == self.session_id - - message = Message.read(data) - - # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key - # https://core.telegram.org/mtproto/security_guidelines#checking-message-length - # 96 = 88 + 8 (incoming message) - assert msg_key == sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] - - # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id - # TODO: check for lower msg_ids - assert message.msg_id % 2 != 0 - - return message - - def net_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) + async def net_worker(self): + logging.info("NetWorkerTask started") while True: - packet = self.recv_queue.get() + packet = await self.recv_queue.get() if packet is None: break try: - data = self.unpack(BytesIO(packet)) + data = MTProto.unpack( + BytesIO(packet), + self.session_id, + self.auth_key, + self.auth_key_id + ) messages = ( data.body.messages @@ -306,13 +259,13 @@ class Session: if isinstance(msg.body, (types.BadMsgNotification, types.BadServerSalt)): msg_id = msg.body.bad_msg_id - elif isinstance(msg.body, (core.FutureSalts, types.RpcResult)): + elif isinstance(msg.body, (FutureSalts, types.RpcResult)): msg_id = msg.body.req_msg_id elif isinstance(msg.body, types.Pong): msg_id = msg.body.msg_id else: if self.client is not None: - self.client.updates_queue.put(msg.body) + self.client.updates_queue.put_nowait(msg.body) if msg_id in self.results: self.results[msg_id].value = getattr(msg.body, "result", msg.body) @@ -322,7 +275,7 @@ class Session: log.info("Send {} acks".format(len(self.pending_acks))) try: - self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False) + await self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False) except (OSError, TimeoutError): pass else: @@ -330,28 +283,32 @@ class Session: except Exception as e: log.error(e, exc_info=True) - log.debug("{} stopped".format(name)) + log.info("NetWorkerTask stopped") - def ping(self): - log.debug("PingThread started") + async def ping(self): + log.info("PingTask started") while True: - self.ping_thread_event.wait(self.PING_INTERVAL) - - if self.ping_thread_event.is_set(): + try: + await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL) + except asyncio.TimeoutError: + pass + else: break try: - self._send(functions.PingDelayDisconnect( - ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 - ), False) + await self._send( + functions.PingDelayDisconnect( + ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 + ), False + ) except (OSError, TimeoutError, RPCError): pass - log.debug("PingThread stopped") + log.info("PingTask stopped") - def next_salt(self): - log.debug("NextSaltThread started") + async def next_salt(self): + log.info("NextSaltTask started") while True: now = datetime.now() @@ -361,45 +318,48 @@ class Session: valid_until = datetime.fromtimestamp(self.current_salt.valid_until) dt = (valid_until - now).total_seconds() - 900 - log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format( - self.current_salt.salt, - dt // 60, - dt % 60, + log.info("Next salt in {:.0f}m {:.0f}s ({})".format( + dt // 60, dt % 60, now + timedelta(seconds=dt) )) - self.next_salt_thread_event.wait(dt) - - if self.next_salt_thread_event.is_set(): + try: + await asyncio.wait_for(self.next_salt_task_event.wait(), dt) + except asyncio.TimeoutError: + pass + else: break try: - self.current_salt = self._send(functions.GetFutureSalts(num=1)).salts[0] + self.current_salt = (await self._send(functions.GetFutureSalts(num=1))).salts[0] except (OSError, TimeoutError, RPCError): self.connection.close() break - log.debug("NextSaltThread stopped") + log.info("NextSaltTask stopped") - def recv(self): - log.debug("RecvThread started") + async def recv(self): + log.info("RecvTask started") while True: - packet = self.connection.recv() + packet = await self.connection.recv() if packet is None or len(packet) == 4: + self.recv_queue.put_nowait(None) + if packet: log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet)))) if self.is_connected.is_set(): - Thread(target=self.restart, name="RestartThread").start() + asyncio.ensure_future(self.restart()) + break - self.recv_queue.put(packet) + self.recv_queue.put_nowait(packet) - log.debug("RecvThread stopped") + log.info("RecvTask stopped") - def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): + async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): message = self.msg_factory(data) msg_id = message.msg_id @@ -408,17 +368,27 @@ class Session: log.debug("Sent:\n{}".format(message)) - payload = self.pack(message) + payload = MTProto.pack( + message, + self.current_salt.salt, + self.session_id, + self.auth_key, + self.auth_key_id + ) try: - self.connection.send(payload) + await self.connection.send(payload) except OSError as e: self.results.pop(msg_id, None) raise e if wait_response: - self.results[msg_id].event.wait(timeout) - result = self.results.pop(msg_id).value + try: + await asyncio.wait_for(self.results[msg_id].event.wait(), timeout) + except asyncio.TimeoutError: + pass + finally: + result = self.results.pop(msg_id).value if result is None: raise TimeoutError @@ -435,14 +405,17 @@ class Session: else: return result - def send( + async def send( self, data: TLObject, retries: int = MAX_RETRIES, timeout: float = WAIT_TIMEOUT, sleep_threshold: float = SLEEP_THRESHOLD ): - self.is_connected.wait(self.WAIT_TIMEOUT) + try: + await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT) + except asyncio.TimeoutError: + pass if isinstance(data, (functions.InvokeWithoutUpdates, functions.InvokeWithTakeout)): query = data.query @@ -453,7 +426,7 @@ class Session: while True: try: - return self._send(data, timeout=timeout) + return await self._send(data, timeout=timeout) except FloodWait as e: amount = e.x @@ -463,7 +436,7 @@ class Session: log.warning('[{}] Sleeping for {}s (required by "{}")'.format( self.client.session_name, amount, query)) - time.sleep(amount) + await asyncio.sleep(amount) except (OSError, TimeoutError, InternalServerError) as e: if retries == 0: raise e from None @@ -473,6 +446,6 @@ class Session: Session.MAX_RETRIES - retries + 1, query, e)) - time.sleep(0.5) + await asyncio.sleep(0.5) - return self.send(data, retries - 1, timeout) + return await self.send(data, retries - 1, timeout) diff --git a/requirements.txt b/requirements.txt index 9e6f7590..1af97061 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ pyaes==1.6.1 -pysocks==1.7.1 \ No newline at end of file +pysocks==1.7.1 +async_lru==1.0.1 +async_generator==1.10 \ No newline at end of file