Change logging hierarchy for loading plugins (#451)

Loading plugins shouldn't be considered a warning
This commit is contained in:
CyanBook 2020-08-21 07:28:27 +02:00 committed by GitHub
parent 2e08266f56
commit c8c6faa96e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
133 changed files with 1349 additions and 1207 deletions

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
__version__ = "0.18.0" __version__ = "0.18.0-async"
__license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)" __license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)"
__copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>" __copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>"

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import io import io
import logging import logging
import math import math
@ -23,14 +24,11 @@ import os
import re import re
import shutil import shutil
import tempfile import tempfile
import threading
import time
from configparser import ConfigParser from configparser import ConfigParser
from hashlib import sha256, md5 from hashlib import sha256, md5
from importlib import import_module from importlib import import_module
from pathlib import Path from pathlib import Path
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Thread
from typing import Union, List, BinaryIO from typing import Union, List, BinaryIO
from pyrogram.api import functions, types from pyrogram.api import functions, types
@ -40,11 +38,13 @@ from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check from pyrogram.client.methods.password.utils import compute_check
from pyrogram.crypto import AES from pyrogram.crypto import AES
from pyrogram.errors import ( from pyrogram.errors import (
PhoneMigrate, NetworkMigrate, SessionPasswordNeeded, PeerIdInvalid, VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneMigrate, NetworkMigrate, SessionPasswordNeeded,
PeerIdInvalid, VolumeLocNotFound, UserMigrate, ChannelPrivate,
AuthBytesInvalid, BadRequest AuthBytesInvalid, BadRequest
) )
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
from .ext import utils, Syncer, BaseClient, Dispatcher from .ext import utils, Syncer, BaseClient, Dispatcher
from .ext.utils import ainput
from .methods import Methods from .methods import Methods
from .storage import Storage, FileStorage, MemoryStorage from .storage import Storage, FileStorage, MemoryStorage
from .types import User, SentCode, TermsOfService from .types import User, SentCode, TermsOfService
@ -127,7 +127,7 @@ class Client(Methods, BaseClient):
Defaults to False. Defaults to False.
workers (``int``, *optional*): workers (``int``, *optional*):
Thread pool size for handling incoming updates. Number of maximum concurrent workers for handling incoming updates.
Defaults to 4. Defaults to 4.
workdir (``str``, *optional*): workdir (``str``, *optional*):
@ -243,6 +243,12 @@ class Client(Methods, BaseClient):
except ConnectionError: except ConnectionError:
pass pass
async def __aenter__(self):
return await self.start()
async def __aexit__(self, *args):
await self.stop()
@property @property
def proxy(self): def proxy(self):
return self._proxy return self._proxy
@ -259,7 +265,7 @@ class Client(Methods, BaseClient):
self._proxy["enabled"] = bool(value.get("enabled", True)) self._proxy["enabled"] = bool(value.get("enabled", True))
self._proxy.update(value) self._proxy.update(value)
def connect(self) -> bool: async def connect(self) -> bool:
""" """
Connect the client to Telegram servers. Connect the client to Telegram servers.
@ -274,16 +280,17 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client is already connected") raise ConnectionError("Client is already connected")
self.load_config() self.load_config()
self.load_session() await self.load_session()
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key()) self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
self.session.start()
await self.session.start()
self.is_connected = True self.is_connected = True
return bool(self.storage.user_id()) return bool(self.storage.user_id())
def disconnect(self): async def disconnect(self):
"""Disconnect the client from Telegram servers. """Disconnect the client from Telegram servers.
Raises: Raises:
@ -296,11 +303,11 @@ class Client(Methods, BaseClient):
if self.is_initialized: if self.is_initialized:
raise ConnectionError("Can't disconnect an initialized client") raise ConnectionError("Can't disconnect an initialized client")
self.session.stop() await self.session.stop()
self.storage.close() self.storage.close()
self.is_connected = False self.is_connected = False
def initialize(self): async def initialize(self):
"""Initialize the client by starting up workers. """Initialize the client by starting up workers.
This method will start updates and download workers. This method will start updates and download workers.
@ -319,33 +326,26 @@ class Client(Methods, BaseClient):
self.load_plugins() self.load_plugins()
if not self.no_updates: if not self.no_updates:
for i in range(self.UPDATES_WORKERS): for _ in range(Client.UPDATES_WORKERS):
self.updates_workers_list.append( self.updates_worker_tasks.append(
Thread( asyncio.ensure_future(self.updates_worker())
target=self.updates_worker,
name="UpdatesWorker#{}".format(i + 1)
)
) )
self.updates_workers_list[-1].start() logging.info("Started {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS))
for i in range(self.DOWNLOAD_WORKERS): for _ in range(Client.DOWNLOAD_WORKERS):
self.download_workers_list.append( self.download_worker_tasks.append(
Thread( asyncio.ensure_future(self.download_worker())
target=self.download_worker,
name="DownloadWorker#{}".format(i + 1)
)
) )
self.download_workers_list[-1].start() logging.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
self.dispatcher.start() await self.dispatcher.start()
await Syncer.add(self)
Syncer.add(self)
self.is_initialized = True self.is_initialized = True
def terminate(self): async def terminate(self):
"""Terminate the client by shutting down workers. """Terminate the client by shutting down workers.
This method does the opposite of :meth:`~Client.initialize`. This method does the opposite of :meth:`~Client.initialize`.
@ -358,37 +358,41 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client is already terminated") raise ConnectionError("Client is already terminated")
if self.takeout_id: if self.takeout_id:
self.send(functions.account.FinishTakeoutSession()) await self.send(functions.account.FinishTakeoutSession())
log.warning("Takeout session {} finished".format(self.takeout_id)) log.warning("Takeout session {} finished".format(self.takeout_id))
Syncer.remove(self) await Syncer.remove(self)
self.dispatcher.stop() await self.dispatcher.stop()
for _ in range(self.DOWNLOAD_WORKERS): for _ in range(Client.DOWNLOAD_WORKERS):
self.download_queue.put(None) self.download_queue.put_nowait(None)
for i in self.download_workers_list: for task in self.download_worker_tasks:
i.join() await task
self.download_workers_list.clear() self.download_worker_tasks.clear()
logging.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
if not self.no_updates: if not self.no_updates:
for _ in range(self.UPDATES_WORKERS): for _ in range(Client.UPDATES_WORKERS):
self.updates_queue.put(None) self.updates_queue.put_nowait(None)
for i in self.updates_workers_list: for task in self.updates_worker_tasks:
i.join() await task
self.updates_workers_list.clear() self.updates_worker_tasks.clear()
for i in self.media_sessions.values(): logging.info("Stopped {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS))
i.stop()
for media_session in self.media_sessions.values():
await media_session.stop()
self.media_sessions.clear() self.media_sessions.clear()
self.is_initialized = False self.is_initialized = False
def send_code(self, phone_number: str) -> SentCode: async def send_code(self, phone_number: str) -> SentCode:
"""Send the confirmation code to the given phone number. """Send the confirmation code to the given phone number.
Parameters: Parameters:
@ -405,7 +409,7 @@ class Client(Methods, BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.auth.SendCode( functions.auth.SendCode(
phone_number=phone_number, phone_number=phone_number,
api_id=self.api_id, api_id=self.api_id,
@ -414,17 +418,17 @@ class Client(Methods, BaseClient):
) )
) )
except (PhoneMigrate, NetworkMigrate) as e: except (PhoneMigrate, NetworkMigrate) as e:
self.session.stop() await self.session.stop()
self.storage.dc_id(e.x) self.storage.dc_id(e.x)
self.storage.auth_key(Auth(self, self.storage.dc_id()).create()) self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key()) self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
self.session.start() await self.session.start()
else: else:
return SentCode._parse(r) return SentCode._parse(r)
def resend_code(self, phone_number: str, phone_code_hash: str) -> SentCode: async def resend_code(self, phone_number: str, phone_code_hash: str) -> SentCode:
"""Re-send the confirmation code using a different type. """Re-send the confirmation code using a different type.
The type of the code to be re-sent is specified in the *next_type* attribute of the :obj:`SentCode` object The type of the code to be re-sent is specified in the *next_type* attribute of the :obj:`SentCode` object
@ -445,7 +449,7 @@ class Client(Methods, BaseClient):
""" """
phone_number = phone_number.strip(" +") phone_number = phone_number.strip(" +")
r = self.send( r = await self.send(
functions.auth.ResendCode( functions.auth.ResendCode(
phone_number=phone_number, phone_number=phone_number,
phone_code_hash=phone_code_hash phone_code_hash=phone_code_hash
@ -454,7 +458,8 @@ class Client(Methods, BaseClient):
return SentCode._parse(r) return SentCode._parse(r)
def sign_in(self, phone_number: str, phone_code_hash: str, phone_code: str) -> Union[User, TermsOfService, bool]: async def sign_in(self, phone_number: str, phone_code_hash: str, phone_code: str) -> Union[
User, TermsOfService, bool]:
"""Authorize a user in Telegram with a valid confirmation code. """Authorize a user in Telegram with a valid confirmation code.
Parameters: Parameters:
@ -479,7 +484,7 @@ class Client(Methods, BaseClient):
""" """
phone_number = phone_number.strip(" +") phone_number = phone_number.strip(" +")
r = self.send( r = await self.send(
functions.auth.SignIn( functions.auth.SignIn(
phone_number=phone_number, phone_number=phone_number,
phone_code_hash=phone_code_hash, phone_code_hash=phone_code_hash,
@ -498,7 +503,7 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user) return User._parse(self, r.user)
def sign_up(self, phone_number: str, phone_code_hash: str, first_name: str, last_name: str = "") -> User: async def sign_up(self, phone_number: str, phone_code_hash: str, first_name: str, last_name: str = "") -> User:
"""Register a new user in Telegram. """Register a new user in Telegram.
Parameters: Parameters:
@ -522,7 +527,7 @@ class Client(Methods, BaseClient):
""" """
phone_number = phone_number.strip(" +") phone_number = phone_number.strip(" +")
r = self.send( r = await self.send(
functions.auth.SignUp( functions.auth.SignUp(
phone_number=phone_number, phone_number=phone_number,
first_name=first_name, first_name=first_name,
@ -536,7 +541,7 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user) return User._parse(self, r.user)
def sign_in_bot(self, bot_token: str) -> User: async def sign_in_bot(self, bot_token: str) -> User:
"""Authorize a bot using its bot token generated by BotFather. """Authorize a bot using its bot token generated by BotFather.
Parameters: Parameters:
@ -551,7 +556,7 @@ class Client(Methods, BaseClient):
""" """
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.auth.ImportBotAuthorization( functions.auth.ImportBotAuthorization(
flags=0, flags=0,
api_id=self.api_id, api_id=self.api_id,
@ -560,28 +565,28 @@ class Client(Methods, BaseClient):
) )
) )
except UserMigrate as e: except UserMigrate as e:
self.session.stop() await self.session.stop()
self.storage.dc_id(e.x) self.storage.dc_id(e.x)
self.storage.auth_key(Auth(self, self.storage.dc_id()).create()) self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key()) self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
self.session.start() await self.session.start()
else: else:
self.storage.user_id(r.user.id) self.storage.user_id(r.user.id)
self.storage.is_bot(True) self.storage.is_bot(True)
return User._parse(self, r.user) return User._parse(self, r.user)
def get_password_hint(self) -> str: async def get_password_hint(self) -> str:
"""Get your Two-Step Verification password hint. """Get your Two-Step Verification password hint.
Returns: Returns:
``str``: On success, the password hint as string is returned. ``str``: On success, the password hint as string is returned.
""" """
return self.send(functions.account.GetPassword()).hint return (await self.send(functions.account.GetPassword())).hint
def check_password(self, password: str) -> User: async def check_password(self, password: str) -> User:
"""Check your Two-Step Verification password and log in. """Check your Two-Step Verification password and log in.
Parameters: Parameters:
@ -594,10 +599,10 @@ class Client(Methods, BaseClient):
Raises: Raises:
BadRequest: In case the password is invalid. BadRequest: In case the password is invalid.
""" """
r = self.send( r = await self.send(
functions.auth.CheckPassword( functions.auth.CheckPassword(
password=compute_check( password=compute_check(
self.send(functions.account.GetPassword()), await self.send(functions.account.GetPassword()),
password password
) )
) )
@ -608,7 +613,7 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user) return User._parse(self, r.user)
def send_recovery_code(self) -> str: async def send_recovery_code(self) -> str:
"""Send a code to your email to recover your password. """Send a code to your email to recover your password.
Returns: Returns:
@ -617,11 +622,11 @@ class Client(Methods, BaseClient):
Raises: Raises:
BadRequest: In case no recovery email was set up. BadRequest: In case no recovery email was set up.
""" """
return self.send( return (await self.send(
functions.auth.RequestPasswordRecovery() functions.auth.RequestPasswordRecovery()
).email_pattern )).email_pattern
def recover_password(self, recovery_code: str) -> User: async def recover_password(self, recovery_code: str) -> User:
"""Recover your password with a recovery code and log in. """Recover your password with a recovery code and log in.
Parameters: Parameters:
@ -634,7 +639,7 @@ class Client(Methods, BaseClient):
Raises: Raises:
BadRequest: In case the recovery code is invalid. BadRequest: In case the recovery code is invalid.
""" """
r = self.send( r = await self.send(
functions.auth.RecoverPassword( functions.auth.RecoverPassword(
code=recovery_code code=recovery_code
) )
@ -645,14 +650,14 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user) return User._parse(self, r.user)
def accept_terms_of_service(self, terms_of_service_id: str) -> bool: async def accept_terms_of_service(self, terms_of_service_id: str) -> bool:
"""Accept the given terms of service. """Accept the given terms of service.
Parameters: Parameters:
terms_of_service_id (``str``): terms_of_service_id (``str``):
The terms of service identifier. The terms of service identifier.
""" """
r = self.send( r = await self.send(
functions.help.AcceptTermsOfService( functions.help.AcceptTermsOfService(
id=types.DataJSON( id=types.DataJSON(
data=terms_of_service_id data=terms_of_service_id
@ -664,15 +669,15 @@ class Client(Methods, BaseClient):
return True return True
def authorize(self) -> User: async def authorize(self) -> User:
if self.bot_token: if self.bot_token:
return self.sign_in_bot(self.bot_token) return await self.sign_in_bot(self.bot_token)
while True: while True:
try: try:
if not self.phone_number: if not self.phone_number:
while True: while True:
value = input("Enter phone number or bot token: ") value = await ainput("Enter phone number or bot token: ")
if not value: if not value:
continue continue
@ -684,11 +689,11 @@ class Client(Methods, BaseClient):
if ":" in value: if ":" in value:
self.bot_token = value self.bot_token = value
return self.sign_in_bot(value) return await self.sign_in_bot(value)
else: else:
self.phone_number = value self.phone_number = value
sent_code = self.send_code(self.phone_number) sent_code = await self.send_code(self.phone_number)
except BadRequest as e: except BadRequest as e:
print(e.MESSAGE) print(e.MESSAGE)
self.phone_number = None self.phone_number = None
@ -697,7 +702,7 @@ class Client(Methods, BaseClient):
break break
if self.force_sms: if self.force_sms:
sent_code = self.resend_code(self.phone_number, sent_code.phone_code_hash) sent_code = await self.resend_code(self.phone_number, sent_code.phone_code_hash)
print("The confirmation code has been sent via {}".format( print("The confirmation code has been sent via {}".format(
{ {
@ -710,10 +715,10 @@ class Client(Methods, BaseClient):
while True: while True:
if not self.phone_code: if not self.phone_code:
self.phone_code = input("Enter confirmation code: ") self.phone_code = await ainput("Enter confirmation code: ")
try: try:
signed_in = self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code) signed_in = await self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code)
except BadRequest as e: except BadRequest as e:
print(e.MESSAGE) print(e.MESSAGE)
self.phone_code = None self.phone_code = None
@ -721,24 +726,24 @@ class Client(Methods, BaseClient):
print(e.MESSAGE) print(e.MESSAGE)
while True: while True:
print("Password hint: {}".format(self.get_password_hint())) print("Password hint: {}".format(await self.get_password_hint()))
if not self.password: if not self.password:
self.password = input("Enter password (empty to recover): ") self.password = await ainput("Enter password (empty to recover): ")
try: try:
if not self.password: if not self.password:
confirm = input("Confirm password recovery (y/n): ") confirm = await ainput("Confirm password recovery (y/n): ")
if confirm == "y": if confirm == "y":
email_pattern = self.send_recovery_code() email_pattern = await self.send_recovery_code()
print("The recovery code has been sent to {}".format(email_pattern)) print("The recovery code has been sent to {}".format(email_pattern))
while True: while True:
recovery_code = input("Enter recovery code: ") recovery_code = await ainput("Enter recovery code: ")
try: try:
return self.recover_password(recovery_code) return await self.recover_password(recovery_code)
except BadRequest as e: except BadRequest as e:
print(e.MESSAGE) print(e.MESSAGE)
except Exception as e: except Exception as e:
@ -747,7 +752,7 @@ class Client(Methods, BaseClient):
else: else:
self.password = None self.password = None
else: else:
return self.check_password(self.password) return await self.check_password(self.password)
except BadRequest as e: except BadRequest as e:
print(e.MESSAGE) print(e.MESSAGE)
self.password = None self.password = None
@ -758,11 +763,11 @@ class Client(Methods, BaseClient):
return signed_in return signed_in
while True: while True:
first_name = input("Enter first name: ") first_name = await ainput("Enter first name: ")
last_name = input("Enter last name (empty to skip): ") last_name = await ainput("Enter last name (empty to skip): ")
try: try:
signed_up = self.sign_up( signed_up = await self.sign_up(
self.phone_number, self.phone_number,
sent_code.phone_code_hash, sent_code.phone_code_hash,
first_name, first_name,
@ -775,11 +780,11 @@ class Client(Methods, BaseClient):
if isinstance(signed_in, TermsOfService): if isinstance(signed_in, TermsOfService):
print("\n" + signed_in.text + "\n") print("\n" + signed_in.text + "\n")
self.accept_terms_of_service(signed_in.id) await self.accept_terms_of_service(signed_in.id)
return signed_up return signed_up
def log_out(self): async def log_out(self):
"""Log out from Telegram and delete the *\\*.session* file. """Log out from Telegram and delete the *\\*.session* file.
When you log out, the current client is stopped and the storage session deleted. When you log out, the current client is stopped and the storage session deleted.
@ -794,13 +799,13 @@ class Client(Methods, BaseClient):
# Log out. # Log out.
app.log_out() app.log_out()
""" """
self.send(functions.auth.LogOut()) await self.send(functions.auth.LogOut())
self.stop() await self.stop()
self.storage.delete() self.storage.delete()
return True return True
def start(self): async def start(self):
"""Start the client. """Start the client.
This method connects the client to Telegram and, in case of new sessions, automatically manages the full This method connects the client to Telegram and, in case of new sessions, automatically manages the full
@ -825,25 +830,25 @@ class Client(Methods, BaseClient):
app.stop() app.stop()
""" """
is_authorized = self.connect() is_authorized = await self.connect()
try: try:
if not is_authorized: if not is_authorized:
self.authorize() await self.authorize()
if not self.storage.is_bot() and self.takeout: if not self.storage.is_bot() and self.takeout:
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id self.takeout_id = (await self.send(functions.account.InitTakeoutSession())).id
log.warning("Takeout session {} initiated".format(self.takeout_id)) log.warning("Takeout session {} initiated".format(self.takeout_id))
self.send(functions.updates.GetState()) await self.send(functions.updates.GetState())
except (Exception, KeyboardInterrupt): except (Exception, KeyboardInterrupt):
self.disconnect() await self.disconnect()
raise raise
else: else:
self.initialize() await self.initialize()
return self return self
def stop(self, block: bool = True): async def stop(self, block: bool = True):
"""Stop the Client. """Stop the Client.
This method disconnects the client from Telegram and stops the underlying tasks. This method disconnects the client from Telegram and stops the underlying tasks.
@ -874,18 +879,18 @@ class Client(Methods, BaseClient):
app.stop() app.stop()
""" """
def do_it(): async def do_it():
self.terminate() await self.terminate()
self.disconnect() await self.disconnect()
if block: if block:
do_it() await do_it()
else: else:
Thread(target=do_it).start() asyncio.ensure_future(do_it())
return self return self
def restart(self, block: bool = True): async def restart(self, block: bool = True):
"""Restart the Client. """Restart the Client.
This method will first call :meth:`~Client.stop` and then :meth:`~Client.start` in a row in order to restart This method will first call :meth:`~Client.stop` and then :meth:`~Client.start` in a row in order to restart
@ -921,19 +926,19 @@ class Client(Methods, BaseClient):
app.stop() app.stop()
""" """
def do_it(): async def do_it():
self.stop() await self.stop()
self.start() await self.start()
if block: if block:
do_it() await do_it()
else: else:
Thread(target=do_it).start() asyncio.ensure_future(do_it())
return self return self
@staticmethod @staticmethod
def idle(stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)): async def idle(stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)):
"""Block the main script execution until a signal is received. """Block the main script execution until a signal is received.
This static method will run an infinite loop in order to block the main script execution and prevent it from This static method will run an infinite loop in order to block the main script execution and prevent it from
@ -978,6 +983,7 @@ class Client(Methods, BaseClient):
""" """
def signal_handler(_, __): def signal_handler(_, __):
logging.info("Stop signal received ({}). Exiting...".format(_))
Client.is_idling = False Client.is_idling = False
for s in stop_signals: for s in stop_signals:
@ -986,9 +992,9 @@ class Client(Methods, BaseClient):
Client.is_idling = True Client.is_idling = True
while Client.is_idling: while Client.is_idling:
time.sleep(1) await asyncio.sleep(1)
def run(self): def run(self, coroutine=None):
"""Start the client, idle the main script and finally stop the client. """Start the client, idle the main script and finally stop the client.
This is a convenience method that calls :meth:`~Client.start`, :meth:`~Client.idle` and :meth:`~Client.stop` in This is a convenience method that calls :meth:`~Client.start`, :meth:`~Client.idle` and :meth:`~Client.stop` in
@ -1010,9 +1016,17 @@ class Client(Methods, BaseClient):
app.run() app.run()
""" """
self.start() loop = asyncio.get_event_loop()
Client.idle() run = loop.run_until_complete
self.stop()
if coroutine is not None:
run(coroutine)
else:
run(self.start())
run(Client.idle())
run(self.stop())
loop.close()
def add_handler(self, handler: Handler, group: int = 0): def add_handler(self, handler: Handler, group: int = 0):
"""Register an update handler. """Register an update handler.
@ -1236,12 +1250,9 @@ class Client(Methods, BaseClient):
return is_min return is_min
def download_worker(self): async def download_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
while True: while True:
packet = self.download_queue.get() packet = await self.download_queue.get()
if packet is None: if packet is None:
break break
@ -1252,7 +1263,7 @@ class Client(Methods, BaseClient):
try: try:
data, directory, file_name, done, progress, progress_args, path = packet data, directory, file_name, done, progress, progress_args, path = packet
temp_file_path = self.get_file( temp_file_path = await self.get_file(
media_type=data.media_type, media_type=data.media_type,
dc_id=data.dc_id, dc_id=data.dc_id,
document_id=data.document_id, document_id=data.document_id,
@ -1289,14 +1300,9 @@ class Client(Methods, BaseClient):
finally: finally:
done.set() done.set()
log.debug("{} stopped".format(name)) async def updates_worker(self):
def updates_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
while True: while True:
updates = self.updates_queue.get() updates = await self.updates_queue.get()
if updates is None: if updates is None:
break break
@ -1328,9 +1334,9 @@ class Client(Methods, BaseClient):
if not isinstance(message, types.MessageEmpty): if not isinstance(message, types.MessageEmpty):
try: try:
diff = self.send( diff = await self.send(
functions.updates.GetChannelDifference( functions.updates.GetChannelDifference(
channel=self.resolve_peer(utils.get_channel_id(channel_id)), channel=await self.resolve_peer(utils.get_channel_id(channel_id)),
filter=types.ChannelMessagesFilter( filter=types.ChannelMessagesFilter(
ranges=[types.MessageRange( ranges=[types.MessageRange(
min_id=update.message.id, min_id=update.message.id,
@ -1348,9 +1354,9 @@ class Client(Methods, BaseClient):
users.update({u.id: u for u in diff.users}) users.update({u.id: u for u in diff.users})
chats.update({c.id: c for c in diff.chats}) chats.update({c.id: c for c in diff.chats})
self.dispatcher.updates_queue.put((update, users, chats)) self.dispatcher.updates_queue.put_nowait((update, users, chats))
elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)): elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
diff = self.send( diff = await self.send(
functions.updates.GetDifference( functions.updates.GetDifference(
pts=updates.pts - updates.pts_count, pts=updates.pts - updates.pts_count,
date=updates.date, date=updates.date,
@ -1359,7 +1365,7 @@ class Client(Methods, BaseClient):
) )
if diff.new_messages: if diff.new_messages:
self.dispatcher.updates_queue.put(( self.dispatcher.updates_queue.put_nowait((
types.UpdateNewMessage( types.UpdateNewMessage(
message=diff.new_messages[0], message=diff.new_messages[0],
pts=updates.pts, pts=updates.pts,
@ -1369,17 +1375,15 @@ class Client(Methods, BaseClient):
{c.id: c for c in diff.chats} {c.id: c for c in diff.chats}
)) ))
else: else:
self.dispatcher.updates_queue.put((diff.other_updates[0], {}, {})) self.dispatcher.updates_queue.put_nowait((diff.other_updates[0], {}, {}))
elif isinstance(updates, types.UpdateShort): elif isinstance(updates, types.UpdateShort):
self.dispatcher.updates_queue.put((updates.update, {}, {})) self.dispatcher.updates_queue.put_nowait((updates.update, {}, {}))
elif isinstance(updates, types.UpdatesTooLong): elif isinstance(updates, types.UpdatesTooLong):
log.info(updates) log.info(updates)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
log.debug("{} stopped".format(name)) async def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
"""Send raw Telegram queries. """Send raw Telegram queries.
This method makes it possible to manually call every single Telegram API method in a low-level manner. This method makes it possible to manually call every single Telegram API method in a low-level manner.
@ -1417,7 +1421,7 @@ class Client(Methods, BaseClient):
if self.takeout_id: if self.takeout_id:
data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data) data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data)
r = self.session.send(data, retries, timeout, self.sleep_threshold) r = await self.session.send(data, retries, timeout, self.sleep_threshold)
self.fetch_peers(getattr(r, "users", [])) self.fetch_peers(getattr(r, "users", []))
self.fetch_peers(getattr(r, "chats", [])) self.fetch_peers(getattr(r, "chats", []))
@ -1497,7 +1501,7 @@ class Client(Methods, BaseClient):
except KeyError: except KeyError:
self.plugins = None self.plugins = None
def load_session(self): async def load_session(self):
self.storage.open() self.storage.open()
session_empty = any([ session_empty = any([
@ -1512,7 +1516,7 @@ class Client(Methods, BaseClient):
self.storage.date(0) self.storage.date(0)
self.storage.test_mode(self.test_mode) self.storage.test_mode(self.test_mode)
self.storage.auth_key(Auth(self, self.storage.dc_id()).create()) self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
self.storage.user_id(None) self.storage.user_id(None)
self.storage.is_bot(None) self.storage.is_bot(None)
@ -1632,13 +1636,13 @@ class Client(Methods, BaseClient):
self.session_name, name, module_path)) self.session_name, name, module_path))
if count > 0: if count > 0:
log.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format( log.info('[{}] Successfully loaded {} plugin{} from "{}"'.format(
self.session_name, count, "s" if count > 1 else "", root)) self.session_name, count, "s" if count > 1 else "", root))
else: else:
log.warning('[{}] No plugin loaded from "{}"'.format( log.warning('[{}] No plugin loaded from "{}"'.format(
self.session_name, root)) self.session_name, root))
def resolve_peer(self, peer_id: Union[int, str]): async def resolve_peer(self, peer_id: Union[int, str]):
"""Get the InputPeer of a known peer id. """Get the InputPeer of a known peer id.
Useful whenever an InputPeer type is required. Useful whenever an InputPeer type is required.
@ -1677,7 +1681,7 @@ class Client(Methods, BaseClient):
try: try:
return self.storage.get_peer_by_username(peer_id) return self.storage.get_peer_by_username(peer_id)
except KeyError: except KeyError:
self.send( await self.send(
functions.contacts.ResolveUsername( functions.contacts.ResolveUsername(
username=peer_id username=peer_id
) )
@ -1694,7 +1698,7 @@ class Client(Methods, BaseClient):
if peer_type == "user": if peer_type == "user":
self.fetch_peers( self.fetch_peers(
self.send( await self.send(
functions.users.GetUsers( functions.users.GetUsers(
id=[ id=[
types.InputUser( types.InputUser(
@ -1706,13 +1710,13 @@ class Client(Methods, BaseClient):
) )
) )
elif peer_type == "chat": elif peer_type == "chat":
self.send( await self.send(
functions.messages.GetChats( functions.messages.GetChats(
id=[-peer_id] id=[-peer_id]
) )
) )
else: else:
self.send( await self.send(
functions.channels.GetChannels( functions.channels.GetChannels(
id=[ id=[
types.InputChannel( types.InputChannel(
@ -1728,7 +1732,7 @@ class Client(Methods, BaseClient):
except KeyError: except KeyError:
raise PeerIdInvalid raise PeerIdInvalid
def save_file( async def save_file(
self, self,
path: Union[str, BinaryIO], path: Union[str, BinaryIO],
file_id: int = None, file_id: int = None,
@ -1786,6 +1790,18 @@ class Client(Methods, BaseClient):
if path is None: if path is None:
return None return None
async def worker(session):
while True:
data = await queue.get()
if data is None:
return
try:
await asyncio.ensure_future(session.send(data))
except Exception as e:
logging.error(e)
part_size = 512 * 1024 part_size = 512 * 1024
if isinstance(path, str): if isinstance(path, str):
@ -1808,15 +1824,20 @@ class Client(Methods, BaseClient):
raise ValueError("Telegram doesn't support uploading files bigger than 2000 MiB") raise ValueError("Telegram doesn't support uploading files bigger than 2000 MiB")
file_total_parts = int(math.ceil(file_size / part_size)) file_total_parts = int(math.ceil(file_size / part_size))
is_big = True if file_size > 10 * 1024 * 1024 else False is_big = file_size > 10 * 1024 * 1024
is_missing_part = True if file_id is not None else False pool_size = 3 if is_big else 1
workers_count = 4 if is_big else 1
is_missing_part = file_id is not None
file_id = file_id or self.rnd_id() file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None md5_sum = md5() if not is_big and not is_missing_part else None
pool = [Session(self, self.storage.dc_id(), self.storage.auth_key(), is_media=True) for _ in range(pool_size)]
session = Session(self, self.storage.dc_id(), self.storage.auth_key(), is_media=True) workers = [asyncio.ensure_future(worker(session)) for session in pool for _ in range(workers_count)]
session.start() queue = asyncio.Queue(16)
try: try:
for session in pool:
await session.start()
with fp: with fp:
fp.seek(part_size * file_part) fp.seek(part_size * file_part)
@ -1828,7 +1849,6 @@ class Client(Methods, BaseClient):
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()]) md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break break
for _ in range(3):
if is_big: if is_big:
rpc = functions.upload.SaveBigFilePart( rpc = functions.upload.SaveBigFilePart(
file_id=file_id, file_id=file_id,
@ -1843,10 +1863,7 @@ class Client(Methods, BaseClient):
bytes=chunk bytes=chunk
) )
if session.send(rpc): await queue.put(rpc)
break
else:
raise AssertionError("Telegram didn't accept chunk #{} of {}".format(file_part, path))
if is_missing_part: if is_missing_part:
return return
@ -1857,7 +1874,7 @@ class Client(Methods, BaseClient):
file_part += 1 file_part += 1
if progress: if progress:
progress(min(file_part * part_size, file_size), file_size, *progress_args) await progress(min(file_part * part_size, file_size), file_size, *progress_args)
except Client.StopTransmission: except Client.StopTransmission:
raise raise
except Exception as e: except Exception as e:
@ -1878,9 +1895,15 @@ class Client(Methods, BaseClient):
md5_checksum=md5_sum md5_checksum=md5_sum
) )
finally: finally:
session.stop() for _ in workers:
await queue.put(None)
def get_file( await asyncio.gather(*workers)
for session in pool:
await session.stop()
async def get_file(
self, self,
media_type: int, media_type: int,
dc_id: int, dc_id: int,
@ -1898,23 +1921,23 @@ class Client(Methods, BaseClient):
progress: callable, progress: callable,
progress_args: tuple = () progress_args: tuple = ()
) -> str: ) -> str:
with self.media_sessions_lock: async with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None) session = self.media_sessions.get(dc_id, None)
if session is None: if session is None:
if dc_id != self.storage.dc_id(): if dc_id != self.storage.dc_id():
session = Session(self, dc_id, Auth(self, dc_id).create(), is_media=True) session = Session(self, dc_id, await Auth(self, dc_id).create(), is_media=True)
session.start() await session.start()
for _ in range(3): for _ in range(3):
exported_auth = self.send( exported_auth = await self.send(
functions.auth.ExportAuthorization( functions.auth.ExportAuthorization(
dc_id=dc_id dc_id=dc_id
) )
) )
try: try:
session.send( await session.send(
functions.auth.ImportAuthorization( functions.auth.ImportAuthorization(
id=exported_auth.id, id=exported_auth.id,
bytes=exported_auth.bytes bytes=exported_auth.bytes
@ -1925,11 +1948,11 @@ class Client(Methods, BaseClient):
else: else:
break break
else: else:
session.stop() await session.stop()
raise AuthBytesInvalid raise AuthBytesInvalid
else: else:
session = Session(self, dc_id, self.storage.auth_key(), is_media=True) session = Session(self, dc_id, self.storage.auth_key(), is_media=True)
session.start() await session.start()
self.media_sessions[dc_id] = session self.media_sessions[dc_id] = session
@ -1984,7 +2007,7 @@ class Client(Methods, BaseClient):
file_name = "" file_name = ""
try: try:
r = session.send( r = await session.send(
functions.upload.GetFile( functions.upload.GetFile(
location=location, location=location,
offset=offset, offset=offset,
@ -2007,7 +2030,7 @@ class Client(Methods, BaseClient):
offset += limit offset += limit
if progress: if progress:
progress( await progress(
min(offset, file_size) min(offset, file_size)
if file_size != 0 if file_size != 0
else offset, else offset,
@ -2015,7 +2038,7 @@ class Client(Methods, BaseClient):
*progress_args *progress_args
) )
r = session.send( r = await session.send(
functions.upload.GetFile( functions.upload.GetFile(
location=location, location=location,
offset=offset, offset=offset,
@ -2024,13 +2047,16 @@ class Client(Methods, BaseClient):
) )
elif isinstance(r, types.upload.FileCdnRedirect): elif isinstance(r, types.upload.FileCdnRedirect):
with self.media_sessions_lock: async with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None) cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None: if cdn_session is None:
cdn_session = Session(self, r.dc_id, Auth(self, r.dc_id).create(), is_media=True, is_cdn=True) cdn_session = Session(
self,
r.dc_id,
await Auth(self, r.dc_id).create(), is_media=True, is_cdn=True)
cdn_session.start() await cdn_session.start()
self.media_sessions[r.dc_id] = cdn_session self.media_sessions[r.dc_id] = cdn_session
@ -2039,7 +2065,7 @@ class Client(Methods, BaseClient):
file_name = f.name file_name = f.name
while True: while True:
r2 = cdn_session.send( r2 = await cdn_session.send(
functions.upload.GetCdnFile( functions.upload.GetCdnFile(
file_token=r.file_token, file_token=r.file_token,
offset=offset, offset=offset,
@ -2049,7 +2075,7 @@ class Client(Methods, BaseClient):
if isinstance(r2, types.upload.CdnFileReuploadNeeded): if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try: try:
session.send( await session.send(
functions.upload.ReuploadCdnFile( functions.upload.ReuploadCdnFile(
file_token=r.file_token, file_token=r.file_token,
request_token=r2.request_token request_token=r2.request_token
@ -2072,7 +2098,7 @@ class Client(Methods, BaseClient):
) )
) )
hashes = session.send( hashes = await session.send(
functions.upload.GetCdnFileHashes( functions.upload.GetCdnFileHashes(
file_token=r.file_token, file_token=r.file_token,
offset=offset offset=offset
@ -2089,7 +2115,7 @@ class Client(Methods, BaseClient):
offset += limit offset += limit
if progress: if progress:
progress( await progress(
min(offset, file_size) min(offset, file_size)
if file_size != 0 if file_size != 0
else offset, else offset,

View File

@ -16,13 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import os import os
import platform import platform
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from queue import Queue
from threading import Lock
from pyrogram import __version__ from pyrogram import __version__
from ..parser import Parser from ..parser import Parser
@ -30,7 +29,7 @@ from ...session.internals import MsgId
class BaseClient: class BaseClient:
class StopTransmission(StopIteration): class StopTransmission(StopAsyncIteration):
pass pass
APP_VERSION = "Pyrogram {}".format(__version__) APP_VERSION = "Pyrogram {}".format(__version__)
@ -52,7 +51,7 @@ class BaseClient:
INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$") INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$")
DIALOGS_AT_ONCE = 100 DIALOGS_AT_ONCE = 100
UPDATES_WORKERS = 4 UPDATES_WORKERS = 4
DOWNLOAD_WORKERS = 1 DOWNLOAD_WORKERS = 4
OFFLINE_SLEEP = 900 OFFLINE_SLEEP = 900
WORKERS = 4 WORKERS = 4
WORKDIR = PARENT_DIR WORKDIR = PARENT_DIR
@ -100,24 +99,24 @@ class BaseClient:
self.session = None self.session = None
self.media_sessions = {} self.media_sessions = {}
self.media_sessions_lock = Lock() self.media_sessions_lock = asyncio.Lock()
self.is_connected = None self.is_connected = None
self.is_initialized = None self.is_initialized = None
self.takeout_id = None self.takeout_id = None
self.updates_queue = Queue() self.updates_queue = asyncio.Queue()
self.updates_workers_list = [] self.updates_worker_tasks = []
self.download_queue = Queue() self.download_queue = asyncio.Queue()
self.download_workers_list = [] self.download_worker_tasks = []
self.disconnect_handler = None self.disconnect_handler = None
def send(self, *args, **kwargs): async def send(self, *args, **kwargs):
pass pass
def resolve_peer(self, *args, **kwargs): async def resolve_peer(self, *args, **kwargs):
pass pass
def fetch_peers(self, *args, **kwargs): def fetch_peers(self, *args, **kwargs):
@ -126,25 +125,46 @@ class BaseClient:
def add_handler(self, *args, **kwargs): def add_handler(self, *args, **kwargs):
pass pass
def save_file(self, *args, **kwargs): async def save_file(self, *args, **kwargs):
pass pass
def get_messages(self, *args, **kwargs): async def get_messages(self, *args, **kwargs):
pass pass
def get_history(self, *args, **kwargs): async def get_history(self, *args, **kwargs):
pass pass
def get_dialogs(self, *args, **kwargs): async def get_dialogs(self, *args, **kwargs):
pass pass
def get_chat_members(self, *args, **kwargs): async def get_chat_members(self, *args, **kwargs):
pass pass
def get_chat_members_count(self, *args, **kwargs): async def get_chat_members_count(self, *args, **kwargs):
pass pass
def answer_inline_query(self, *args, **kwargs): async def answer_inline_query(self, *args, **kwargs):
pass
async def get_profile_photos(self, *args, **kwargs):
pass
async def edit_message_text(self, *args, **kwargs):
pass
async def edit_inline_text(self, *args, **kwargs):
pass
async def edit_message_media(self, *args, **kwargs):
pass
async def edit_inline_media(self, *args, **kwargs):
pass
async def edit_message_reply_markup(self, *args, **kwargs):
pass
async def edit_inline_reply_markup(self, *args, **kwargs):
pass pass
def guess_mime_type(self, *args, **kwargs): def guess_mime_type(self, *args, **kwargs):
@ -152,24 +172,3 @@ class BaseClient:
def guess_extension(self, *args, **kwargs): def guess_extension(self, *args, **kwargs):
pass pass
def get_profile_photos(self, *args, **kwargs):
pass
def edit_message_text(self, *args, **kwargs):
pass
def edit_inline_text(self, *args, **kwargs):
pass
def edit_message_media(self, *args, **kwargs):
pass
def edit_inline_media(self, *args, **kwargs):
pass
def edit_message_reply_markup(self, *args, **kwargs):
pass
def edit_inline_reply_markup(self, *args, **kwargs):
pass

View File

@ -16,11 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import threading
from collections import OrderedDict from collections import OrderedDict
from queue import Queue
from threading import Thread, Lock
import pyrogram import pyrogram
from pyrogram.api.types import ( from pyrogram.api.types import (
@ -69,75 +67,74 @@ class Dispatcher:
self.client = client self.client = client
self.workers = workers self.workers = workers
self.workers_list = [] self.update_worker_tasks = []
self.locks_list = [] self.locks_list = []
self.updates_queue = Queue() self.updates_queue = asyncio.Queue()
self.groups = OrderedDict() self.groups = OrderedDict()
async def message_parser(update, users, chats):
return await pyrogram.Message._parse(
self.client, update.message, users, chats,
isinstance(update, UpdateNewScheduledMessage)
), MessageHandler
async def deleted_messages_parser(update, users, chats):
return utils.parse_deleted_messages(self.client, update), DeletedMessagesHandler
async def callback_query_parser(update, users, chats):
return await pyrogram.CallbackQuery._parse(self.client, update, users), CallbackQueryHandler
async def user_status_parser(update, users, chats):
return pyrogram.User._parse_user_status(self.client, update), UserStatusHandler
async def inline_query_parser(update, users, chats):
return pyrogram.InlineQuery._parse(self.client, update, users), InlineQueryHandler
async def poll_parser(update, users, chats):
return pyrogram.Poll._parse_update(self.client, update), PollHandler
async def chosen_inline_result_parser(update, users, chats):
return pyrogram.ChosenInlineResult._parse(self.client, update, users), ChosenInlineResultHandler
self.update_parsers = { self.update_parsers = {
Dispatcher.MESSAGE_UPDATES: Dispatcher.MESSAGE_UPDATES: message_parser,
lambda upd, usr, cht: ( Dispatcher.DELETE_MESSAGES_UPDATES: deleted_messages_parser,
pyrogram.Message._parse( Dispatcher.CALLBACK_QUERY_UPDATES: callback_query_parser,
self.client, (UpdateUserStatus,): user_status_parser,
upd.message, (UpdateBotInlineQuery,): inline_query_parser,
usr, (UpdateMessagePoll,): poll_parser,
cht, (UpdateBotInlineSend,): chosen_inline_result_parser
isinstance(upd, UpdateNewScheduledMessage)
),
MessageHandler
),
Dispatcher.DELETE_MESSAGES_UPDATES:
lambda upd, usr, cht: (utils.parse_deleted_messages(self.client, upd), DeletedMessagesHandler),
Dispatcher.CALLBACK_QUERY_UPDATES:
lambda upd, usr, cht: (pyrogram.CallbackQuery._parse(self.client, upd, usr), CallbackQueryHandler),
(UpdateUserStatus,):
lambda upd, usr, cht: (pyrogram.User._parse_user_status(self.client, upd), UserStatusHandler),
(UpdateBotInlineQuery,):
lambda upd, usr, cht: (pyrogram.InlineQuery._parse(self.client, upd, usr), InlineQueryHandler),
(UpdateMessagePoll,):
lambda upd, usr, cht: (pyrogram.Poll._parse_update(self.client, upd), PollHandler),
(UpdateBotInlineSend,):
lambda upd, usr, cht: (pyrogram.ChosenInlineResult._parse(self.client, upd, usr),
ChosenInlineResultHandler)
} }
self.update_parsers = {key: value for key_tuple, value in self.update_parsers.items() for key in key_tuple} self.update_parsers = {key: value for key_tuple, value in self.update_parsers.items() for key in key_tuple}
def start(self): async def start(self):
for i in range(self.workers): for i in range(self.workers):
self.locks_list.append(Lock()) self.locks_list.append(asyncio.Lock())
self.workers_list.append( self.update_worker_tasks.append(
Thread( asyncio.ensure_future(self.update_worker(self.locks_list[-1]))
target=self.update_worker,
name="UpdateWorker#{}".format(i + 1),
args=(self.locks_list[-1],)
)
) )
self.workers_list[-1].start() logging.info("Started {} UpdateWorkerTasks".format(self.workers))
def stop(self): async def stop(self):
for _ in range(self.workers): for i in range(self.workers):
self.updates_queue.put(None) self.updates_queue.put_nowait(None)
for worker in self.workers_list: for i in self.update_worker_tasks:
worker.join() await i
self.workers_list.clear() self.update_worker_tasks.clear()
self.locks_list.clear()
self.groups.clear() self.groups.clear()
logging.info("Stopped {} UpdateWorkerTasks".format(self.workers))
def add_handler(self, handler, group: int): def add_handler(self, handler, group: int):
async def fn():
for lock in self.locks_list: for lock in self.locks_list:
lock.acquire() await lock.acquire()
try: try:
if group not in self.groups: if group not in self.groups:
@ -149,9 +146,12 @@ class Dispatcher:
for lock in self.locks_list: for lock in self.locks_list:
lock.release() lock.release()
asyncio.ensure_future(fn())
def remove_handler(self, handler, group: int): def remove_handler(self, handler, group: int):
async def fn():
for lock in self.locks_list: for lock in self.locks_list:
lock.acquire() await lock.acquire()
try: try:
if group not in self.groups: if group not in self.groups:
@ -162,12 +162,11 @@ class Dispatcher:
for lock in self.locks_list: for lock in self.locks_list:
lock.release() lock.release()
def update_worker(self, lock): asyncio.ensure_future(fn())
name = threading.current_thread().name
log.debug("{} started".format(name))
async def update_worker(self, lock):
while True: while True:
packet = self.updates_queue.get() packet = await self.updates_queue.get()
if packet is None: if packet is None:
break break
@ -177,12 +176,12 @@ class Dispatcher:
parser = self.update_parsers.get(type(update), None) parser = self.update_parsers.get(type(update), None)
parsed_update, handler_type = ( parsed_update, handler_type = (
parser(update, users, chats) await parser(update, users, chats)
if parser is not None if parser is not None
else (None, type(None)) else (None, type(None))
) )
with lock: async with lock:
for group in self.groups.values(): for group in self.groups.values():
for handler in group: for handler in group:
args = None args = None
@ -202,7 +201,7 @@ class Dispatcher:
continue continue
try: try:
handler.callback(self.client, *args) await handler.callback(self.client, *args)
except pyrogram.StopPropagation: except pyrogram.StopPropagation:
raise raise
except pyrogram.ContinuePropagation: except pyrogram.ContinuePropagation:
@ -215,5 +214,3 @@ class Dispatcher:
pass pass
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
log.debug("{} stopped".format(name))

View File

@ -16,9 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import time import time
from threading import Thread, Event, Lock
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -27,13 +27,18 @@ class Syncer:
INTERVAL = 20 INTERVAL = 20
clients = {} clients = {}
thread = None event = None
event = Event() lock = None
lock = Lock()
@classmethod @classmethod
def add(cls, client): async def add(cls, client):
with cls.lock: if cls.event is None:
cls.event = asyncio.Event()
if cls.lock is None:
cls.lock = asyncio.Lock()
async with cls.lock:
cls.sync(client) cls.sync(client)
cls.clients[id(client)] = client cls.clients[id(client)] = client
@ -42,8 +47,8 @@ class Syncer:
cls.start() cls.start()
@classmethod @classmethod
def remove(cls, client): async def remove(cls, client):
with cls.lock: async with cls.lock:
cls.sync(client) cls.sync(client)
del cls.clients[id(client)] del cls.clients[id(client)]
@ -54,24 +59,23 @@ class Syncer:
@classmethod @classmethod
def start(cls): def start(cls):
cls.event.clear() cls.event.clear()
cls.thread = Thread(target=cls.worker, name=cls.__name__) asyncio.ensure_future(cls.worker())
cls.thread.start()
@classmethod @classmethod
def stop(cls): def stop(cls):
cls.event.set() cls.event.set()
@classmethod @classmethod
def worker(cls): async def worker(cls):
while True: while True:
cls.event.wait(cls.INTERVAL) try:
await asyncio.wait_for(cls.event.wait(), cls.INTERVAL)
if cls.event.is_set(): except asyncio.TimeoutError:
break async with cls.lock:
with cls.lock:
for client in cls.clients.values(): for client in cls.clients.values():
cls.sync(client) cls.sync(client)
else:
break
@classmethod @classmethod
def sync(cls, client): def sync(cls, client):

View File

@ -16,8 +16,11 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import base64 import base64
import struct import struct
import sys
from concurrent.futures.thread import ThreadPoolExecutor
from typing import List from typing import List
from typing import Union from typing import Union
@ -80,6 +83,15 @@ def decode_file_ref(file_ref: str) -> bytes:
return base64.urlsafe_b64decode(file_ref + "=" * (-len(file_ref) % 4)) return base64.urlsafe_b64decode(file_ref + "=" * (-len(file_ref) % 4))
async def ainput(prompt: str = ""):
print(prompt, end="", flush=True)
with ThreadPoolExecutor(1) as executor:
return (await asyncio.get_event_loop().run_in_executor(
executor, sys.stdin.readline
)).rstrip()
def get_offset_date(dialogs): def get_offset_date(dialogs):
for m in reversed(dialogs.messages): for m in reversed(dialogs.messages):
if isinstance(m, types.MessageEmpty): if isinstance(m, types.MessageEmpty):
@ -141,24 +153,24 @@ def get_input_media_from_file_id(
raise ValueError("Unknown media type: {}".format(file_id_str)) raise ValueError("Unknown media type: {}".format(file_id_str))
def parse_messages(client, messages: types.messages.Messages, replies: int = 1) -> List["pyrogram.Message"]: async def parse_messages(client, messages: types.messages.Messages, replies: int = 1) -> List["pyrogram.Message"]:
users = {i.id: i for i in messages.users} users = {i.id: i for i in messages.users}
chats = {i.id: i for i in messages.chats} chats = {i.id: i for i in messages.chats}
if not messages.messages: if not messages.messages:
return pyrogram.List() return pyrogram.List()
parsed_messages = [ parsed_messages = []
pyrogram.Message._parse(client, message, users, chats, replies=0)
for message in messages.messages for message in messages.messages:
] parsed_messages.append(await pyrogram.Message._parse(client, message, users, chats, replies=0))
if replies: if replies:
messages_with_replies = {i.id: getattr(i, "reply_to_msg_id", None) for i in messages.messages} messages_with_replies = {i.id: getattr(i, "reply_to_msg_id", None) for i in messages.messages}
reply_message_ids = [i[0] for i in filter(lambda x: x[1] is not None, messages_with_replies.items())] reply_message_ids = [i[0] for i in filter(lambda x: x[1] is not None, messages_with_replies.items())]
if reply_message_ids: if reply_message_ids:
reply_messages = client.get_messages( reply_messages = await client.get_messages(
parsed_messages[0].chat.id, parsed_messages[0].chat.id,
reply_to_message_ids=reply_message_ids, reply_to_message_ids=reply_message_ids,
replies=replies - 1 replies=replies - 1

View File

@ -21,7 +21,7 @@ from pyrogram.client.ext import BaseClient
class AnswerCallbackQuery(BaseClient): class AnswerCallbackQuery(BaseClient):
def answer_callback_query( async def answer_callback_query(
self, self,
callback_query_id: str, callback_query_id: str,
text: str = None, text: str = None,
@ -68,7 +68,7 @@ class AnswerCallbackQuery(BaseClient):
# Answer with alert # Answer with alert
app.answer_callback_query(query_id, text=text, show_alert=True) app.answer_callback_query(query_id, text=text, show_alert=True)
""" """
return self.send( return await self.send(
functions.messages.SetBotCallbackAnswer( functions.messages.SetBotCallbackAnswer(
query_id=int(callback_query_id), query_id=int(callback_query_id),
cache_time=cache_time, cache_time=cache_time,

View File

@ -24,7 +24,7 @@ from ...types.inline_mode import InlineQueryResult
class AnswerInlineQuery(BaseClient): class AnswerInlineQuery(BaseClient):
def answer_inline_query( async def answer_inline_query(
self, self,
inline_query_id: str, inline_query_id: str,
results: List[InlineQueryResult], results: List[InlineQueryResult],
@ -93,10 +93,15 @@ class AnswerInlineQuery(BaseClient):
"Title", "Title",
InputTextMessageContent("Message content"))]) InputTextMessageContent("Message content"))])
""" """
return self.send( written_results = [] # Py 3.5 doesn't support await inside comprehensions
for r in results:
written_results.append(await r.write())
return await self.send(
functions.messages.SetInlineBotResults( functions.messages.SetInlineBotResults(
query_id=int(inline_query_id), query_id=int(inline_query_id),
results=[r.write() for r in results], results=written_results,
cache_time=cache_time, cache_time=cache_time,
gallery=is_gallery or None, gallery=is_gallery or None,
private=is_personal or None, private=is_personal or None,

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class GetGameHighScores(BaseClient): class GetGameHighScores(BaseClient):
def get_game_high_scores( async def get_game_high_scores(
self, self,
user_id: Union[int, str], user_id: Union[int, str],
chat_id: Union[int, str], chat_id: Union[int, str],
@ -59,11 +59,11 @@ class GetGameHighScores(BaseClient):
""" """
# TODO: inline_message_id # TODO: inline_message_id
r = self.send( r = await self.send(
functions.messages.GetGameHighScores( functions.messages.GetGameHighScores(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
id=message_id, id=message_id,
user_id=self.resolve_peer(user_id) user_id=await self.resolve_peer(user_id)
) )
) )

View File

@ -24,7 +24,7 @@ from pyrogram.errors import UnknownError
class GetInlineBotResults(BaseClient): class GetInlineBotResults(BaseClient):
def get_inline_bot_results( async def get_inline_bot_results(
self, self,
bot: Union[int, str], bot: Union[int, str],
query: str = "", query: str = "",
@ -70,9 +70,9 @@ class GetInlineBotResults(BaseClient):
# TODO: Don't return the raw type # TODO: Don't return the raw type
try: try:
return self.send( return await self.send(
functions.messages.GetInlineBotResults( functions.messages.GetInlineBotResults(
bot=self.resolve_peer(bot), bot=await self.resolve_peer(bot),
peer=types.InputPeerSelf(), peer=types.InputPeerSelf(),
query=query, query=query,
offset=offset, offset=offset,

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class RequestCallbackAnswer(BaseClient): class RequestCallbackAnswer(BaseClient):
def request_callback_answer( async def request_callback_answer(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int, message_id: int,
@ -64,9 +64,9 @@ class RequestCallbackAnswer(BaseClient):
# Telegram only wants bytes, but we are allowed to pass strings too. # Telegram only wants bytes, but we are allowed to pass strings too.
data = bytes(callback_data, "utf-8") if isinstance(callback_data, str) else callback_data data = bytes(callback_data, "utf-8") if isinstance(callback_data, str) else callback_data
return self.send( return await self.send(
functions.messages.GetBotCallbackAnswer( functions.messages.GetBotCallbackAnswer(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
msg_id=message_id, msg_id=message_id,
data=data data=data
), ),

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendGame(BaseClient): class SendGame(BaseClient):
def send_game( async def send_game(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
game_short_name: str, game_short_name: str,
@ -67,9 +67,9 @@ class SendGame(BaseClient):
app.send_game(chat_id, "gamename") app.send_game(chat_id, "gamename")
""" """
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaGame( media=types.InputMediaGame(
id=types.InputGameShortName( id=types.InputGameShortName(
bot_id=types.InputUserSelf(), bot_id=types.InputUserSelf(),
@ -86,7 +86,7 @@ class SendGame(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats} {i.id: i for i in r.chats}

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class SendInlineBotResult(BaseClient): class SendInlineBotResult(BaseClient):
def send_inline_bot_result( async def send_inline_bot_result(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
query_id: int, query_id: int,
@ -65,9 +65,9 @@ class SendInlineBotResult(BaseClient):
app.send_inline_bot_result(chat_id, query_id, result_id) app.send_inline_bot_result(chat_id, query_id, result_id)
""" """
return self.send( return await self.send(
functions.messages.SendInlineBotResult( functions.messages.SendInlineBotResult(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
query_id=query_id, query_id=query_id,
id=result_id, id=result_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SetGameScore(BaseClient): class SetGameScore(BaseClient):
def set_game_score( async def set_game_score(
self, self,
user_id: Union[int, str], user_id: Union[int, str],
score: int, score: int,
@ -75,12 +75,12 @@ class SetGameScore(BaseClient):
# Force set new score # Force set new score
app.set_game_score(user_id, 25, force=True) app.set_game_score(user_id, 25, force=True)
""" """
r = self.send( r = await self.send(
functions.messages.SetGameScore( functions.messages.SetGameScore(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
score=score, score=score,
id=message_id, id=message_id,
user_id=self.resolve_peer(user_id), user_id=await self.resolve_peer(user_id),
force=force or None, force=force or None,
edit_message=not disable_edit_message or None edit_message=not disable_edit_message or None
) )
@ -88,7 +88,7 @@ class SetGameScore(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats} {i.id: i for i in r.chats}

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class AddChatMembers(BaseClient): class AddChatMembers(BaseClient):
def add_chat_members( async def add_chat_members(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
user_ids: Union[Union[int, str], List[Union[int, str]]], user_ids: Union[Union[int, str], List[Union[int, str]]],
@ -60,26 +60,26 @@ class AddChatMembers(BaseClient):
# Change forward_limit (for basic groups only) # Change forward_limit (for basic groups only)
app.add_chat_members(chat_id, user_id, forward_limit=25) app.add_chat_members(chat_id, user_id, forward_limit=25)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if not isinstance(user_ids, list): if not isinstance(user_ids, list):
user_ids = [user_ids] user_ids = [user_ids]
if isinstance(peer, types.InputPeerChat): if isinstance(peer, types.InputPeerChat):
for user_id in user_ids: for user_id in user_ids:
self.send( await self.send(
functions.messages.AddChatUser( functions.messages.AddChatUser(
chat_id=peer.chat_id, chat_id=peer.chat_id,
user_id=self.resolve_peer(user_id), user_id=await self.resolve_peer(user_id),
fwd_limit=forward_limit fwd_limit=forward_limit
) )
) )
else: else:
self.send( await self.send(
functions.channels.InviteToChannel( functions.channels.InviteToChannel(
channel=peer, channel=peer,
users=[ users=[
self.resolve_peer(user_id) await self.resolve_peer(user_id)
for user_id in user_ids for user_id in user_ids
] ]
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class ArchiveChats(BaseClient): class ArchiveChats(BaseClient):
def archive_chats( async def archive_chats(
self, self,
chat_ids: Union[int, str, List[Union[int, str]]], chat_ids: Union[int, str, List[Union[int, str]]],
) -> bool: ) -> bool:
@ -50,14 +50,19 @@ class ArchiveChats(BaseClient):
if not isinstance(chat_ids, list): if not isinstance(chat_ids, list):
chat_ids = [chat_ids] chat_ids = [chat_ids]
self.send( folder_peers = []
functions.folders.EditPeerFolders(
folder_peers=[ for chat in chat_ids:
folder_peers.append(
types.InputFolderPeer( types.InputFolderPeer(
peer=self.resolve_peer(chat), peer=await self.resolve_peer(chat),
folder_id=1 folder_id=1
) for chat in chat_ids )
] )
await self.send(
functions.folders.EditPeerFolders(
folder_peers=folder_peers
) )
) )

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class CreateChannel(BaseClient): class CreateChannel(BaseClient):
def create_channel( async def create_channel(
self, self,
title: str, title: str,
description: str = "" description: str = ""
@ -44,7 +44,7 @@ class CreateChannel(BaseClient):
app.create_channel("Channel Title", "Channel Description") app.create_channel("Channel Title", "Channel Description")
""" """
r = self.send( r = await self.send(
functions.channels.CreateChannel( functions.channels.CreateChannel(
title=title, title=title,
about=description, about=description,

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class CreateGroup(BaseClient): class CreateGroup(BaseClient):
def create_group( async def create_group(
self, self,
title: str, title: str,
users: Union[Union[int, str], List[Union[int, str]]] users: Union[Union[int, str], List[Union[int, str]]]
@ -55,10 +55,10 @@ class CreateGroup(BaseClient):
if not isinstance(users, list): if not isinstance(users, list):
users = [users] users = [users]
r = self.send( r = await self.send(
functions.messages.CreateChat( functions.messages.CreateChat(
title=title, title=title,
users=[self.resolve_peer(u) for u in users] users=[await self.resolve_peer(u) for u in users]
) )
) )

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class CreateSupergroup(BaseClient): class CreateSupergroup(BaseClient):
def create_supergroup( async def create_supergroup(
self, self,
title: str, title: str,
description: str = "" description: str = ""
@ -48,7 +48,7 @@ class CreateSupergroup(BaseClient):
app.create_supergroup("Supergroup Title", "Supergroup Description") app.create_supergroup("Supergroup Title", "Supergroup Description")
""" """
r = self.send( r = await self.send(
functions.channels.CreateChannel( functions.channels.CreateChannel(
title=title, title=title,
about=description, about=description,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class DeleteChannel(BaseClient): class DeleteChannel(BaseClient):
def delete_channel(self, chat_id: Union[int, str]) -> bool: async def delete_channel(self, chat_id: Union[int, str]) -> bool:
"""Delete a channel. """Delete a channel.
Parameters: Parameters:
@ -38,9 +38,9 @@ class DeleteChannel(BaseClient):
app.delete_channel(channel_id) app.delete_channel(channel_id)
""" """
self.send( await self.send(
functions.channels.DeleteChannel( functions.channels.DeleteChannel(
channel=self.resolve_peer(chat_id) channel=await self.resolve_peer(chat_id)
) )
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class DeleteChatPhoto(BaseClient): class DeleteChatPhoto(BaseClient):
def delete_chat_photo( async def delete_chat_photo(
self, self,
chat_id: Union[int, str] chat_id: Union[int, str]
) -> bool: ) -> bool:
@ -46,17 +46,17 @@ class DeleteChatPhoto(BaseClient):
app.delete_chat_photo(chat_id) app.delete_chat_photo(chat_id)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat): if isinstance(peer, types.InputPeerChat):
self.send( await self.send(
functions.messages.EditChatPhoto( functions.messages.EditChatPhoto(
chat_id=peer.chat_id, chat_id=peer.chat_id,
photo=types.InputChatPhotoEmpty() photo=types.InputChatPhotoEmpty()
) )
) )
elif isinstance(peer, types.InputPeerChannel): elif isinstance(peer, types.InputPeerChannel):
self.send( await self.send(
functions.channels.EditPhoto( functions.channels.EditPhoto(
channel=peer, channel=peer,
photo=types.InputChatPhotoEmpty() photo=types.InputChatPhotoEmpty()

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class DeleteSupergroup(BaseClient): class DeleteSupergroup(BaseClient):
def delete_supergroup(self, chat_id: Union[int, str]) -> bool: async def delete_supergroup(self, chat_id: Union[int, str]) -> bool:
"""Delete a supergroup. """Delete a supergroup.
Parameters: Parameters:
@ -38,9 +38,9 @@ class DeleteSupergroup(BaseClient):
app.delete_supergroup(supergroup_id) app.delete_supergroup(supergroup_id)
""" """
self.send( await self.send(
functions.channels.DeleteChannel( functions.channels.DeleteChannel(
channel=self.resolve_peer(chat_id) channel=await self.resolve_peer(chat_id)
) )
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class ExportChatInviteLink(BaseClient): class ExportChatInviteLink(BaseClient):
def export_chat_invite_link( async def export_chat_invite_link(
self, self,
chat_id: Union[int, str] chat_id: Union[int, str]
) -> str: ) -> str:
@ -55,13 +55,13 @@ class ExportChatInviteLink(BaseClient):
link = app.export_chat_invite_link(chat_id) link = app.export_chat_invite_link(chat_id)
print(link) print(link)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, (types.InputPeerChat, types.InputPeerChannel)): if isinstance(peer, (types.InputPeerChat, types.InputPeerChannel)):
return self.send( return (await self.send(
functions.messages.ExportChatInvite( functions.messages.ExportChatInvite(
peer=peer peer=peer
) )
).link )).link
else: else:
raise ValueError('The chat_id "{}" belongs to a user'.format(chat_id)) raise ValueError('The chat_id "{}" belongs to a user'.format(chat_id))

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient, utils
class GetChat(BaseClient): class GetChat(BaseClient):
def get_chat( async def get_chat(
self, self,
chat_id: Union[int, str] chat_id: Union[int, str]
) -> Union["pyrogram.Chat", "pyrogram.ChatPreview"]: ) -> Union["pyrogram.Chat", "pyrogram.ChatPreview"]:
@ -55,7 +55,7 @@ class GetChat(BaseClient):
match = self.INVITE_LINK_RE.match(str(chat_id)) match = self.INVITE_LINK_RE.match(str(chat_id))
if match: if match:
r = self.send( r = await self.send(
functions.messages.CheckChatInvite( functions.messages.CheckChatInvite(
hash=match.group(1) hash=match.group(1)
) )
@ -72,13 +72,13 @@ class GetChat(BaseClient):
if isinstance(r.chat, types.Channel): if isinstance(r.chat, types.Channel):
chat_id = utils.get_channel_id(r.chat.id) chat_id = utils.get_channel_id(r.chat.id)
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel): if isinstance(peer, types.InputPeerChannel):
r = self.send(functions.channels.GetFullChannel(channel=peer)) r = await self.send(functions.channels.GetFullChannel(channel=peer))
elif isinstance(peer, (types.InputPeerUser, types.InputPeerSelf)): elif isinstance(peer, (types.InputPeerUser, types.InputPeerSelf)):
r = self.send(functions.users.GetFullUser(id=peer)) r = await self.send(functions.users.GetFullUser(id=peer))
else: else:
r = self.send(functions.messages.GetFullChat(chat_id=peer.chat_id)) r = await self.send(functions.messages.GetFullChat(chat_id=peer.chat_id))
return pyrogram.Chat._parse_full(self, r) return await pyrogram.Chat._parse_full(self, r)

View File

@ -21,11 +21,12 @@ from typing import Union
import pyrogram import pyrogram
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.errors import UserNotParticipant from pyrogram.errors import UserNotParticipant
from ...ext import BaseClient from ...ext import BaseClient
class GetChatMember(BaseClient): class GetChatMember(BaseClient):
def get_chat_member( async def get_chat_member(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
user_id: Union[int, str] user_id: Union[int, str]
@ -50,11 +51,11 @@ class GetChatMember(BaseClient):
dan = app.get_chat_member("pyrogramchat", "haskell") dan = app.get_chat_member("pyrogramchat", "haskell")
print(dan) print(dan)
""" """
chat = self.resolve_peer(chat_id) chat = await self.resolve_peer(chat_id)
user = self.resolve_peer(user_id) user = await self.resolve_peer(user_id)
if isinstance(chat, types.InputPeerChat): if isinstance(chat, types.InputPeerChat):
r = self.send( r = await self.send(
functions.messages.GetFullChat( functions.messages.GetFullChat(
chat_id=chat.chat_id chat_id=chat.chat_id
) )
@ -75,7 +76,7 @@ class GetChatMember(BaseClient):
else: else:
raise UserNotParticipant raise UserNotParticipant
elif isinstance(chat, types.InputPeerChannel): elif isinstance(chat, types.InputPeerChannel):
r = self.send( r = await self.send(
functions.channels.GetParticipant( functions.channels.GetParticipant(
channel=chat, channel=chat,
user_id=user user_id=user

View File

@ -36,7 +36,7 @@ class Filters:
class GetChatMembers(BaseClient): class GetChatMembers(BaseClient):
def get_chat_members( async def get_chat_members(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
offset: int = 0, offset: int = 0,
@ -103,10 +103,10 @@ class GetChatMembers(BaseClient):
# Get all bots # Get all bots
app.get_chat_members("pyrogramchat", filter="bots") app.get_chat_members("pyrogramchat", filter="bots")
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat): if isinstance(peer, types.InputPeerChat):
r = self.send( r = await self.send(
functions.messages.GetFullChat( functions.messages.GetFullChat(
chat_id=peer.chat_id chat_id=peer.chat_id
) )
@ -134,7 +134,7 @@ class GetChatMembers(BaseClient):
else: else:
raise ValueError("Invalid filter \"{}\"".format(filter)) raise ValueError("Invalid filter \"{}\"".format(filter))
r = self.send( r = await self.send(
functions.channels.GetParticipants( functions.channels.GetParticipants(
channel=peer, channel=peer,
filter=filter, filter=filter,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class GetChatMembersCount(BaseClient): class GetChatMembersCount(BaseClient):
def get_chat_members_count( async def get_chat_members_count(
self, self,
chat_id: Union[int, str] chat_id: Union[int, str]
) -> int: ) -> int:
@ -45,19 +45,23 @@ class GetChatMembersCount(BaseClient):
count = app.get_chat_members_count("pyrogramchat") count = app.get_chat_members_count("pyrogramchat")
print(count) print(count)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat): if isinstance(peer, types.InputPeerChat):
return self.send( r = await self.send(
functions.messages.GetChats( functions.messages.GetChats(
id=[peer.chat_id] id=[peer.chat_id]
) )
).chats[0].participants_count )
return r.chats[0].participants_count
elif isinstance(peer, types.InputPeerChannel): elif isinstance(peer, types.InputPeerChannel):
return self.send( r = await self.send(
functions.channels.GetFullChannel( functions.channels.GetFullChannel(
channel=peer channel=peer
) )
).full_chat.participants_count )
return r.full_chat.participants_count
else: else:
raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id)) raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id))

View File

@ -27,7 +27,7 @@ log = logging.getLogger(__name__)
class GetDialogs(BaseClient): class GetDialogs(BaseClient):
def get_dialogs( async def get_dialogs(
self, self,
offset_date: int = 0, offset_date: int = 0,
limit: int = 100, limit: int = 100,
@ -65,9 +65,9 @@ class GetDialogs(BaseClient):
""" """
if pinned_only: if pinned_only:
r = self.send(functions.messages.GetPinnedDialogs(folder_id=0)) r = await self.send(functions.messages.GetPinnedDialogs(folder_id=0))
else: else:
r = self.send( r = await self.send(
functions.messages.GetDialogs( functions.messages.GetDialogs(
offset_date=offset_date, offset_date=offset_date,
offset_id=0, offset_id=0,
@ -94,7 +94,7 @@ class GetDialogs(BaseClient):
else: else:
chat_id = utils.get_peer_id(to_id) chat_id = utils.get_peer_id(to_id)
messages[chat_id] = pyrogram.Message._parse(self, message, users, chats) messages[chat_id] = await pyrogram.Message._parse(self, message, users, chats)
parsed_dialogs = [] parsed_dialogs = []

View File

@ -21,7 +21,7 @@ from ...ext import BaseClient
class GetDialogsCount(BaseClient): class GetDialogsCount(BaseClient):
def get_dialogs_count(self, pinned_only: bool = False) -> int: async def get_dialogs_count(self, pinned_only: bool = False) -> int:
"""Get the total count of your dialogs. """Get the total count of your dialogs.
pinned_only (``bool``, *optional*): pinned_only (``bool``, *optional*):
@ -39,9 +39,9 @@ class GetDialogsCount(BaseClient):
""" """
if pinned_only: if pinned_only:
return len(self.send(functions.messages.GetPinnedDialogs(folder_id=0)).dialogs) return len((await self.send(functions.messages.GetPinnedDialogs(folder_id=0))).dialogs)
else: else:
r = self.send( r = await self.send(
functions.messages.GetDialogs( functions.messages.GetDialogs(
offset_date=0, offset_date=0,
offset_id=0, offset_id=0,

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient, utils
class GetNearbyChats(BaseClient): class GetNearbyChats(BaseClient):
def get_nearby_chats( async def get_nearby_chats(
self, self,
latitude: float, latitude: float,
longitude: float longitude: float
@ -48,7 +48,7 @@ class GetNearbyChats(BaseClient):
print(chats) print(chats)
""" """
r = self.send( r = await self.send(
functions.contacts.GetLocated( functions.contacts.GetLocated(
geo_point=types.InputGeoPoint( geo_point=types.InputGeoPoint(
lat=latitude, lat=latitude,

View File

@ -17,10 +17,12 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from string import ascii_lowercase from string import ascii_lowercase
from typing import Union, Generator from typing import Union, Generator, Optional
import pyrogram import pyrogram
from async_generator import async_generator, yield_
from pyrogram.api import types from pyrogram.api import types
from ...ext import BaseClient from ...ext import BaseClient
@ -38,13 +40,14 @@ QUERYABLE_FILTERS = (Filters.ALL, Filters.KICKED, Filters.RESTRICTED)
class IterChatMembers(BaseClient): class IterChatMembers(BaseClient):
def iter_chat_members( @async_generator
async def iter_chat_members(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
limit: int = 0, limit: int = 0,
query: str = "", query: str = "",
filter: str = Filters.ALL filter: str = Filters.ALL
) -> Generator["pyrogram.ChatMember", None, None]: ) -> Optional[Generator["pyrogram.ChatMember", None, None]]:
"""Iterate through the members of a chat sequentially. """Iterate through the members of a chat sequentially.
This convenience method does the same as repeatedly calling :meth:`~Client.get_chat_members` in a loop, thus saving you This convenience method does the same as repeatedly calling :meth:`~Client.get_chat_members` in a loop, thus saving you
@ -97,7 +100,7 @@ class IterChatMembers(BaseClient):
queries = [query] if query else QUERIES queries = [query] if query else QUERIES
total = limit or (1 << 31) - 1 total = limit or (1 << 31) - 1
limit = min(200, total) limit = min(200, total)
resolved_chat_id = self.resolve_peer(chat_id) resolved_chat_id = await self.resolve_peer(chat_id)
if filter not in QUERYABLE_FILTERS: if filter not in QUERYABLE_FILTERS:
queries = [""] queries = [""]
@ -106,7 +109,7 @@ class IterChatMembers(BaseClient):
offset = 0 offset = 0
while True: while True:
chat_members = self.get_chat_members( chat_members = await self.get_chat_members(
chat_id=chat_id, chat_id=chat_id,
offset=offset, offset=offset,
limit=limit, limit=limit,
@ -128,7 +131,7 @@ class IterChatMembers(BaseClient):
if user_id in yielded: if user_id in yielded:
continue continue
yield chat_member await yield_(chat_member)
yielded.add(chat_member.user.id) yielded.add(chat_member.user.id)

View File

@ -16,18 +16,21 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Generator from typing import Generator, Optional
from async_generator import async_generator, yield_
import pyrogram import pyrogram
from ...ext import BaseClient from ...ext import BaseClient
class IterDialogs(BaseClient): class IterDialogs(BaseClient):
def iter_dialogs( @async_generator
async def iter_dialogs(
self, self,
offset_date: int = 0, limit: int = 0,
limit: int = None offset_date: int = 0
) -> Generator["pyrogram.Dialog", None, None]: ) -> Optional[Generator["pyrogram.Dialog", None, None]]:
"""Iterate through a user's dialogs sequentially. """Iterate through a user's dialogs sequentially.
This convenience method does the same as repeatedly calling :meth:`~Client.get_dialogs` in a loop, thus saving This convenience method does the same as repeatedly calling :meth:`~Client.get_dialogs` in a loop, thus saving
@ -35,14 +38,14 @@ class IterDialogs(BaseClient):
single call. single call.
Parameters: Parameters:
offset_date (``int``):
The offset date in Unix time taken from the top message of a :obj:`Dialog`.
Defaults to 0 (most recent dialog).
limit (``int``, *optional*): limit (``int``, *optional*):
Limits the number of dialogs to be retrieved. Limits the number of dialogs to be retrieved.
By default, no limit is applied and all dialogs are returned. By default, no limit is applied and all dialogs are returned.
offset_date (``int``):
The offset date in Unix time taken from the top message of a :obj:`Dialog`.
Defaults to 0 (most recent dialog).
Returns: Returns:
``Generator``: A generator yielding :obj:`Dialog` objects. ``Generator``: A generator yielding :obj:`Dialog` objects.
@ -57,12 +60,12 @@ class IterDialogs(BaseClient):
total = limit or (1 << 31) - 1 total = limit or (1 << 31) - 1
limit = min(100, total) limit = min(100, total)
pinned_dialogs = self.get_dialogs( pinned_dialogs = await self.get_dialogs(
pinned_only=True pinned_only=True
) )
for dialog in pinned_dialogs: for dialog in pinned_dialogs:
yield dialog await yield_(dialog)
current += 1 current += 1
@ -70,7 +73,7 @@ class IterDialogs(BaseClient):
return return
while True: while True:
dialogs = self.get_dialogs( dialogs = await self.get_dialogs(
offset_date=offset_date, offset_date=offset_date,
limit=limit limit=limit
) )
@ -81,7 +84,7 @@ class IterDialogs(BaseClient):
offset_date = dialogs[-1].top_message.date offset_date = dialogs[-1].top_message.date
for dialog in dialogs: for dialog in dialogs:
yield dialog await yield_(dialog)
current += 1 current += 1

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class JoinChat(BaseClient): class JoinChat(BaseClient):
def join_chat( async def join_chat(
self, self,
chat_id: Union[int, str] chat_id: Union[int, str]
): ):
@ -53,7 +53,7 @@ class JoinChat(BaseClient):
match = self.INVITE_LINK_RE.match(str(chat_id)) match = self.INVITE_LINK_RE.match(str(chat_id))
if match: if match:
chat = self.send( chat = await self.send(
functions.messages.ImportChatInvite( functions.messages.ImportChatInvite(
hash=match.group(1) hash=match.group(1)
) )
@ -63,9 +63,9 @@ class JoinChat(BaseClient):
elif isinstance(chat.chats[0], types.Channel): elif isinstance(chat.chats[0], types.Channel):
return pyrogram.Chat._parse_channel_chat(self, chat.chats[0]) return pyrogram.Chat._parse_channel_chat(self, chat.chats[0])
else: else:
chat = self.send( chat = await self.send(
functions.channels.JoinChannel( functions.channels.JoinChannel(
channel=self.resolve_peer(chat_id) channel=await self.resolve_peer(chat_id)
) )
) )

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class KickChatMember(BaseClient): class KickChatMember(BaseClient):
def kick_chat_member( async def kick_chat_member(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
user_id: Union[int, str], user_id: Union[int, str],
@ -68,11 +68,11 @@ class KickChatMember(BaseClient):
# Kick chat member and automatically unban after 24h # Kick chat member and automatically unban after 24h
app.kick_chat_member(chat_id, user_id, int(time.time() + 86400)) app.kick_chat_member(chat_id, user_id, int(time.time() + 86400))
""" """
chat_peer = self.resolve_peer(chat_id) chat_peer = await self.resolve_peer(chat_id)
user_peer = self.resolve_peer(user_id) user_peer = await self.resolve_peer(user_id)
if isinstance(chat_peer, types.InputPeerChannel): if isinstance(chat_peer, types.InputPeerChannel):
r = self.send( r = await self.send(
functions.channels.EditBanned( functions.channels.EditBanned(
channel=chat_peer, channel=chat_peer,
user_id=user_peer, user_id=user_peer,
@ -90,7 +90,7 @@ class KickChatMember(BaseClient):
) )
) )
else: else:
r = self.send( r = await self.send(
functions.messages.DeleteChatUser( functions.messages.DeleteChatUser(
chat_id=abs(chat_id), chat_id=abs(chat_id),
user_id=user_peer user_id=user_peer
@ -99,7 +99,7 @@ class KickChatMember(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats} {i.id: i for i in r.chats}

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class LeaveChat(BaseClient): class LeaveChat(BaseClient):
def leave_chat( async def leave_chat(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
delete: bool = False delete: bool = False
@ -48,16 +48,16 @@ class LeaveChat(BaseClient):
# Leave basic chat and also delete the dialog # Leave basic chat and also delete the dialog
app.leave_chat(chat_id, delete=True) app.leave_chat(chat_id, delete=True)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel): if isinstance(peer, types.InputPeerChannel):
return self.send( return await self.send(
functions.channels.LeaveChannel( functions.channels.LeaveChannel(
channel=self.resolve_peer(chat_id) channel=await self.resolve_peer(chat_id)
) )
) )
elif isinstance(peer, types.InputPeerChat): elif isinstance(peer, types.InputPeerChat):
r = self.send( r = await self.send(
functions.messages.DeleteChatUser( functions.messages.DeleteChatUser(
chat_id=peer.chat_id, chat_id=peer.chat_id,
user_id=types.InputPeerSelf() user_id=types.InputPeerSelf()
@ -65,7 +65,7 @@ class LeaveChat(BaseClient):
) )
if delete: if delete:
self.send( await self.send(
functions.messages.DeleteHistory( functions.messages.DeleteHistory(
peer=peer, peer=peer,
max_id=0 max_id=0

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class PinChatMessage(BaseClient): class PinChatMessage(BaseClient):
def pin_chat_message( async def pin_chat_message(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int, message_id: int,
@ -56,9 +56,9 @@ class PinChatMessage(BaseClient):
# Pin without notification # Pin without notification
app.pin_chat_message(chat_id, message_id, disable_notification=True) app.pin_chat_message(chat_id, message_id, disable_notification=True)
""" """
self.send( await self.send(
functions.messages.UpdatePinnedMessage( functions.messages.UpdatePinnedMessage(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
id=message_id, id=message_id,
silent=disable_notification or None silent=disable_notification or None
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class PromoteChatMember(BaseClient): class PromoteChatMember(BaseClient):
def promote_chat_member( async def promote_chat_member(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
user_id: Union[int, str], user_id: Union[int, str],
@ -84,10 +84,10 @@ class PromoteChatMember(BaseClient):
# Promote chat member to supergroup admin # Promote chat member to supergroup admin
app.promote_chat_member(chat_id, user_id) app.promote_chat_member(chat_id, user_id)
""" """
self.send( await self.send(
functions.channels.EditAdmin( functions.channels.EditAdmin(
channel=self.resolve_peer(chat_id), channel=await self.resolve_peer(chat_id),
user_id=self.resolve_peer(user_id), user_id=await self.resolve_peer(user_id),
admin_rights=types.ChatAdminRights( admin_rights=types.ChatAdminRights(
change_info=can_change_info or None, change_info=can_change_info or None,
post_messages=can_post_messages or None, post_messages=can_post_messages or None,

View File

@ -24,7 +24,7 @@ from ...types.user_and_chats import Chat, ChatPermissions
class RestrictChatMember(BaseClient): class RestrictChatMember(BaseClient):
def restrict_chat_member( async def restrict_chat_member(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
user_id: Union[int, str], user_id: Union[int, str],
@ -71,10 +71,10 @@ class RestrictChatMember(BaseClient):
# Chat member can only send text messages # Chat member can only send text messages
app.restrict_chat_member(chat_id, user_id, ChatPermissions(can_send_messages=True)) app.restrict_chat_member(chat_id, user_id, ChatPermissions(can_send_messages=True))
""" """
r = self.send( r = await self.send(
functions.channels.EditBanned( functions.channels.EditBanned(
channel=self.resolve_peer(chat_id), channel=await self.resolve_peer(chat_id),
user_id=self.resolve_peer(user_id), user_id=await self.resolve_peer(user_id),
banned_rights=types.ChatBannedRights( banned_rights=types.ChatBannedRights(
until_date=until_date, until_date=until_date,
send_messages=True if not permissions.can_send_messages else None, send_messages=True if not permissions.can_send_messages else None,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetAdministratorTitle(BaseClient): class SetAdministratorTitle(BaseClient):
def set_administrator_title( async def set_administrator_title(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
user_id: Union[int, str], user_id: Union[int, str],
@ -54,15 +54,15 @@ class SetAdministratorTitle(BaseClient):
app.set_administrator_title(chat_id, user_id, "ฅ^•ﻌ•^ฅ") app.set_administrator_title(chat_id, user_id, "ฅ^•ﻌ•^ฅ")
""" """
chat_id = self.resolve_peer(chat_id) chat_id = await self.resolve_peer(chat_id)
user_id = self.resolve_peer(user_id) user_id = await self.resolve_peer(user_id)
r = self.send( r = (await self.send(
functions.channels.GetParticipant( functions.channels.GetParticipant(
channel=chat_id, channel=chat_id,
user_id=user_id user_id=user_id
) )
).participant )).participant
if isinstance(r, types.ChannelParticipantCreator): if isinstance(r, types.ChannelParticipantCreator):
admin_rights = types.ChatAdminRights( admin_rights = types.ChatAdminRights(
@ -104,7 +104,7 @@ class SetAdministratorTitle(BaseClient):
if not admin_rights.add_admins: if not admin_rights.add_admins:
admin_rights.add_admins = None admin_rights.add_admins = None
self.send( await self.send(
functions.channels.EditAdmin( functions.channels.EditAdmin(
channel=chat_id, channel=chat_id,
user_id=user_id, user_id=user_id,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetChatDescription(BaseClient): class SetChatDescription(BaseClient):
def set_chat_description( async def set_chat_description(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
description: str description: str
@ -49,10 +49,10 @@ class SetChatDescription(BaseClient):
app.set_chat_description(chat_id, "New Description") app.set_chat_description(chat_id, "New Description")
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, (types.InputPeerChannel, types.InputPeerChat)): if isinstance(peer, (types.InputPeerChannel, types.InputPeerChat)):
self.send( await self.send(
functions.messages.EditChatAbout( functions.messages.EditChatAbout(
peer=peer, peer=peer,
about=description about=description

View File

@ -24,7 +24,7 @@ from ...types.user_and_chats import Chat, ChatPermissions
class SetChatPermissions(BaseClient): class SetChatPermissions(BaseClient):
def set_chat_permissions( async def set_chat_permissions(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
permissions: ChatPermissions, permissions: ChatPermissions,
@ -63,9 +63,9 @@ class SetChatPermissions(BaseClient):
) )
) )
""" """
r = self.send( r = await self.send(
functions.messages.EditChatDefaultBannedRights( functions.messages.EditChatDefaultBannedRights(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
banned_rights=types.ChatBannedRights( banned_rights=types.ChatBannedRights(
until_date=0, until_date=0,
send_messages=True if not permissions.can_send_messages else None, send_messages=True if not permissions.can_send_messages else None,

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient, utils
class SetChatPhoto(BaseClient): class SetChatPhoto(BaseClient):
def set_chat_photo( async def set_chat_photo(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
*, *,
@ -79,32 +79,32 @@ class SetChatPhoto(BaseClient):
# Set chat photo using an exiting Video file_id # Set chat photo using an exiting Video file_id
app.set_chat_photo(chat_id, video=video.file_id, file_ref=video.file_ref) app.set_chat_photo(chat_id, video=video.file_id, file_ref=video.file_ref)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(photo, str): if isinstance(photo, str):
if os.path.isfile(photo): if os.path.isfile(photo):
photo = types.InputChatUploadedPhoto( photo = types.InputChatUploadedPhoto(
file=self.save_file(photo), file=await self.save_file(photo),
video=self.save_file(video) video=await self.save_file(video)
) )
else: else:
photo = utils.get_input_media_from_file_id(photo, file_ref, 2) photo = utils.get_input_media_from_file_id(photo, file_ref, 2)
photo = types.InputChatPhoto(id=photo.id) photo = types.InputChatPhoto(id=photo.id)
else: else:
photo = types.InputChatUploadedPhoto( photo = types.InputChatUploadedPhoto(
file=self.save_file(photo), file=await self.save_file(photo),
video=self.save_file(video) video=await self.save_file(video)
) )
if isinstance(peer, types.InputPeerChat): if isinstance(peer, types.InputPeerChat):
self.send( await self.send(
functions.messages.EditChatPhoto( functions.messages.EditChatPhoto(
chat_id=peer.chat_id, chat_id=peer.chat_id,
photo=photo photo=photo
) )
) )
elif isinstance(peer, types.InputPeerChannel): elif isinstance(peer, types.InputPeerChannel):
self.send( await self.send(
functions.channels.EditPhoto( functions.channels.EditPhoto(
channel=peer, channel=peer,
photo=photo photo=photo

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetChatTitle(BaseClient): class SetChatTitle(BaseClient):
def set_chat_title( async def set_chat_title(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
title: str title: str
@ -54,17 +54,17 @@ class SetChatTitle(BaseClient):
app.set_chat_title(chat_id, "New Title") app.set_chat_title(chat_id, "New Title")
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat): if isinstance(peer, types.InputPeerChat):
self.send( await self.send(
functions.messages.EditChatTitle( functions.messages.EditChatTitle(
chat_id=peer.chat_id, chat_id=peer.chat_id,
title=title title=title
) )
) )
elif isinstance(peer, types.InputPeerChannel): elif isinstance(peer, types.InputPeerChannel):
self.send( await self.send(
functions.channels.EditTitle( functions.channels.EditTitle(
channel=peer, channel=peer,
title=title title=title

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetSlowMode(BaseClient): class SetSlowMode(BaseClient):
def set_slow_mode( async def set_slow_mode(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
seconds: Union[int, None] seconds: Union[int, None]
@ -51,9 +51,9 @@ class SetSlowMode(BaseClient):
app.set_slow_mode("pyrogramchat", None) app.set_slow_mode("pyrogramchat", None)
""" """
self.send( await self.send(
functions.channels.ToggleSlowMode( functions.channels.ToggleSlowMode(
channel=self.resolve_peer(chat_id), channel=await self.resolve_peer(chat_id),
seconds=0 if seconds is None else seconds seconds=0 if seconds is None else seconds
) )
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UnarchiveChats(BaseClient): class UnarchiveChats(BaseClient):
def unarchive_chats( async def unarchive_chats(
self, self,
chat_ids: Union[int, str, List[Union[int, str]]], chat_ids: Union[int, str, List[Union[int, str]]],
) -> bool: ) -> bool:
@ -50,14 +50,19 @@ class UnarchiveChats(BaseClient):
if not isinstance(chat_ids, list): if not isinstance(chat_ids, list):
chat_ids = [chat_ids] chat_ids = [chat_ids]
self.send( folder_peers = []
functions.folders.EditPeerFolders(
folder_peers=[ for chat in chat_ids:
folder_peers.append(
types.InputFolderPeer( types.InputFolderPeer(
peer=self.resolve_peer(chat), peer=await self.resolve_peer(chat),
folder_id=0 folder_id=0
) for chat in chat_ids )
] )
await self.send(
functions.folders.EditPeerFolders(
folder_peers=folder_peers
) )
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UnbanChatMember(BaseClient): class UnbanChatMember(BaseClient):
def unban_chat_member( async def unban_chat_member(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
user_id: Union[int, str] user_id: Union[int, str]
@ -49,10 +49,10 @@ class UnbanChatMember(BaseClient):
# Unban chat member right now # Unban chat member right now
app.unban_chat_member(chat_id, user_id) app.unban_chat_member(chat_id, user_id)
""" """
self.send( await self.send(
functions.channels.EditBanned( functions.channels.EditBanned(
channel=self.resolve_peer(chat_id), channel=await self.resolve_peer(chat_id),
user_id=self.resolve_peer(user_id), user_id=await self.resolve_peer(user_id),
banned_rights=types.ChatBannedRights( banned_rights=types.ChatBannedRights(
until_date=0 until_date=0
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UnpinChatMessage(BaseClient): class UnpinChatMessage(BaseClient):
def unpin_chat_message( async def unpin_chat_message(
self, self,
chat_id: Union[int, str] chat_id: Union[int, str]
) -> bool: ) -> bool:
@ -43,9 +43,9 @@ class UnpinChatMessage(BaseClient):
app.unpin_chat_message(chat_id) app.unpin_chat_message(chat_id)
""" """
self.send( await self.send(
functions.messages.UpdatePinnedMessage( functions.messages.UpdatePinnedMessage(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
id=0 id=0
) )
) )

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UpdateChatUsername(BaseClient): class UpdateChatUsername(BaseClient):
def update_chat_username( async def update_chat_username(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
username: Union[str, None] username: Union[str, None]
@ -50,11 +50,11 @@ class UpdateChatUsername(BaseClient):
app.update_chat_username(chat_id, "new_username") app.update_chat_username(chat_id, "new_username")
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel): if isinstance(peer, types.InputPeerChannel):
return bool( return bool(
self.send( await self.send(
functions.channels.UpdateUsername( functions.channels.UpdateUsername(
channel=peer, channel=peer,
username=username or "" username=username or ""

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class AddContacts(BaseClient): class AddContacts(BaseClient):
def add_contacts( async def add_contacts(
self, self,
contacts: List["pyrogram.InputPhoneContact"] contacts: List["pyrogram.InputPhoneContact"]
): ):
@ -47,7 +47,7 @@ class AddContacts(BaseClient):
InputPhoneContact("38987654321", "Bar"), InputPhoneContact("38987654321", "Bar"),
InputPhoneContact("01234567891", "Baz")]) InputPhoneContact("01234567891", "Baz")])
""" """
imported_contacts = self.send( imported_contacts = await self.send(
functions.contacts.ImportContacts( functions.contacts.ImportContacts(
contacts=contacts contacts=contacts
) )

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class DeleteContacts(BaseClient): class DeleteContacts(BaseClient):
def delete_contacts( async def delete_contacts(
self, self,
ids: List[int] ids: List[int]
): ):
@ -47,14 +47,14 @@ class DeleteContacts(BaseClient):
for i in ids: for i in ids:
try: try:
input_user = self.resolve_peer(i) input_user = await self.resolve_peer(i)
except PeerIdInvalid: except PeerIdInvalid:
continue continue
else: else:
if isinstance(input_user, types.InputPeerUser): if isinstance(input_user, types.InputPeerUser):
contacts.append(input_user) contacts.append(input_user)
return self.send( return await self.send(
functions.contacts.DeleteContacts( functions.contacts.DeleteContacts(
id=contacts id=contacts
) )

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from typing import List from typing import List
@ -27,7 +28,7 @@ log = logging.getLogger(__name__)
class GetContacts(BaseClient): class GetContacts(BaseClient):
def get_contacts(self) -> List["pyrogram.User"]: async def get_contacts(self) -> List["pyrogram.User"]:
"""Get contacts from your Telegram address book. """Get contacts from your Telegram address book.
Returns: Returns:
@ -39,5 +40,5 @@ class GetContacts(BaseClient):
contacts = app.get_contacts() contacts = app.get_contacts()
print(contacts) print(contacts)
""" """
contacts = self.send(functions.contacts.GetContacts(hash=0)) contacts = await self.send(functions.contacts.GetContacts(hash=0))
return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users) return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users)

View File

@ -21,7 +21,7 @@ from ...ext import BaseClient
class GetContactsCount(BaseClient): class GetContactsCount(BaseClient):
def get_contacts_count(self) -> int: async def get_contacts_count(self) -> int:
"""Get the total count of contacts from your Telegram address book. """Get the total count of contacts from your Telegram address book.
Returns: Returns:
@ -34,4 +34,4 @@ class GetContactsCount(BaseClient):
print(count) print(count)
""" """
return len(self.send(functions.contacts.GetContacts(hash=0)).contacts) return len((await self.send(functions.contacts.GetContacts(hash=0))).contacts)

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class DeleteMessages(BaseClient): class DeleteMessages(BaseClient):
def delete_messages( async def delete_messages(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_ids: Iterable[int], message_ids: Iterable[int],
@ -62,18 +62,18 @@ class DeleteMessages(BaseClient):
# Delete messages only on your side (without revoking) # Delete messages only on your side (without revoking)
app.delete_messages(chat_id, message_id, revoke=False) app.delete_messages(chat_id, message_id, revoke=False)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
message_ids = list(message_ids) if not isinstance(message_ids, int) else [message_ids] message_ids = list(message_ids) if not isinstance(message_ids, int) else [message_ids]
if isinstance(peer, types.InputPeerChannel): if isinstance(peer, types.InputPeerChannel):
r = self.send( r = await self.send(
functions.channels.DeleteMessages( functions.channels.DeleteMessages(
channel=peer, channel=peer,
id=message_ids id=message_ids
) )
) )
else: else:
r = self.send( r = await self.send(
functions.messages.DeleteMessages( functions.messages.DeleteMessages(
id=message_ids, id=message_ids,
revoke=revoke or None revoke=revoke or None

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import binascii import binascii
import os import os
import struct import struct
@ -32,7 +33,7 @@ DEFAULT_DOWNLOAD_DIR = "downloads/"
class DownloadMedia(BaseClient): class DownloadMedia(BaseClient):
def download_media( async def download_media(
self, self,
message: Union["pyrogram.Message", str], message: Union["pyrogram.Message", str],
file_ref: str = None, file_ref: str = None,
@ -202,7 +203,7 @@ class DownloadMedia(BaseClient):
except (AssertionError, binascii.Error, struct.error): except (AssertionError, binascii.Error, struct.error):
raise FileIdInvalid from None raise FileIdInvalid from None
done = Event() done = asyncio.Event()
path = [None] path = [None]
directory, file_name = os.path.split(file_name) directory, file_name = os.path.split(file_name)
@ -239,9 +240,9 @@ class DownloadMedia(BaseClient):
) )
# Cast to string because Path objects aren't supported by Python 3.5 # Cast to string because Path objects aren't supported by Python 3.5
self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path)) self.download_queue.put_nowait((data, str(directory), str(file_name), done, progress, progress_args, path))
if block: if block:
done.wait() await done.wait()
return path[0] return path[0]

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class EditInlineCaption(BaseClient): class EditInlineCaption(BaseClient):
def edit_inline_caption( async def edit_inline_caption(
self, self,
inline_message_id: str, inline_message_id: str,
caption: str, caption: str,
@ -58,7 +58,7 @@ class EditInlineCaption(BaseClient):
# Bots only # Bots only
app.edit_inline_caption(inline_message_id, "new media caption") app.edit_inline_caption(inline_message_id, "new media caption")
""" """
return self.edit_inline_text( return await self.edit_inline_text(
inline_message_id=inline_message_id, inline_message_id=inline_message_id,
text=caption, text=caption,
parse_mode=parse_mode, parse_mode=parse_mode,

View File

@ -29,7 +29,7 @@ from pyrogram.client.types.input_media import InputMedia
class EditInlineMedia(BaseClient): class EditInlineMedia(BaseClient):
def edit_inline_media( async def edit_inline_media(
self, self,
inline_message_id: str, inline_message_id: str,
media: InputMedia, media: InputMedia,
@ -109,11 +109,11 @@ class EditInlineMedia(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5) media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5)
return self.send( return await self.send(
functions.messages.EditInlineBotMessage( functions.messages.EditInlineBotMessage(
id=utils.unpack_inline_message_id(inline_message_id), id=utils.unpack_inline_message_id(inline_message_id),
media=media, media=media,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )

View File

@ -22,7 +22,7 @@ from pyrogram.client.ext import BaseClient, utils
class EditInlineReplyMarkup(BaseClient): class EditInlineReplyMarkup(BaseClient):
def edit_inline_reply_markup( async def edit_inline_reply_markup(
self, self,
inline_message_id: str, inline_message_id: str,
reply_markup: "pyrogram.InlineKeyboardMarkup" = None reply_markup: "pyrogram.InlineKeyboardMarkup" = None
@ -50,7 +50,7 @@ class EditInlineReplyMarkup(BaseClient):
InlineKeyboardMarkup([[ InlineKeyboardMarkup([[
InlineKeyboardButton("New button", callback_data="new_data")]])) InlineKeyboardButton("New button", callback_data="new_data")]]))
""" """
return self.send( return await self.send(
functions.messages.EditInlineBotMessage( functions.messages.EditInlineBotMessage(
id=utils.unpack_inline_message_id(inline_message_id), id=utils.unpack_inline_message_id(inline_message_id),
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient, utils
class EditInlineText(BaseClient): class EditInlineText(BaseClient):
def edit_inline_text( async def edit_inline_text(
self, self,
inline_message_id: str, inline_message_id: str,
text: str, text: str,
@ -71,11 +71,11 @@ class EditInlineText(BaseClient):
disable_web_page_preview=True) disable_web_page_preview=True)
""" """
return self.send( return await self.send(
functions.messages.EditInlineBotMessage( functions.messages.EditInlineBotMessage(
id=utils.unpack_inline_message_id(inline_message_id), id=utils.unpack_inline_message_id(inline_message_id),
no_webpage=disable_web_page_preview or None, no_webpage=disable_web_page_preview or None,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(text, parse_mode) **await self.parser.parse(text, parse_mode)
) )
) )

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class EditMessageCaption(BaseClient): class EditMessageCaption(BaseClient):
def edit_message_caption( async def edit_message_caption(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int, message_id: int,
@ -63,7 +63,7 @@ class EditMessageCaption(BaseClient):
app.edit_message_caption(chat_id, message_id, "new media caption") app.edit_message_caption(chat_id, message_id, "new media caption")
""" """
return self.edit_message_text( return await self.edit_message_text(
chat_id=chat_id, chat_id=chat_id,
message_id=message_id, message_id=message_id,
text=caption, text=caption,

View File

@ -31,7 +31,7 @@ from pyrogram.client.types.input_media import InputMedia
class EditMessageMedia(BaseClient): class EditMessageMedia(BaseClient):
def edit_message_media( async def edit_message_media(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int, message_id: int,
@ -85,11 +85,11 @@ class EditMessageMedia(BaseClient):
if isinstance(media, InputMediaPhoto): if isinstance(media, InputMediaPhoto):
if os.path.isfile(media.media): if os.path.isfile(media.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedPhoto( media=types.InputMediaUploadedPhoto(
file=self.save_file(media.media) file=await self.save_file(media.media)
) )
) )
) )
@ -109,13 +109,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 2) media = utils.get_input_media_from_file_id(media.media, media.file_ref, 2)
elif isinstance(media, InputMediaVideo): elif isinstance(media, InputMediaVideo):
if os.path.isfile(media.media): if os.path.isfile(media.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument( media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "video/mp4", mime_type=self.guess_mime_type(media.media) or "video/mp4",
thumb=self.save_file(media.thumb), thumb=await self.save_file(media.thumb),
file=self.save_file(media.media), file=await self.save_file(media.media),
attributes=[ attributes=[
types.DocumentAttributeVideo( types.DocumentAttributeVideo(
supports_streaming=media.supports_streaming or None, supports_streaming=media.supports_streaming or None,
@ -146,13 +146,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 4) media = utils.get_input_media_from_file_id(media.media, media.file_ref, 4)
elif isinstance(media, InputMediaAudio): elif isinstance(media, InputMediaAudio):
if os.path.isfile(media.media): if os.path.isfile(media.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument( media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "audio/mpeg", mime_type=self.guess_mime_type(media.media) or "audio/mpeg",
thumb=self.save_file(media.thumb), thumb=await self.save_file(media.thumb),
file=self.save_file(media.media), file=await self.save_file(media.media),
attributes=[ attributes=[
types.DocumentAttributeAudio( types.DocumentAttributeAudio(
duration=media.duration, duration=media.duration,
@ -182,13 +182,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 9) media = utils.get_input_media_from_file_id(media.media, media.file_ref, 9)
elif isinstance(media, InputMediaAnimation): elif isinstance(media, InputMediaAnimation):
if os.path.isfile(media.media): if os.path.isfile(media.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument( media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "video/mp4", mime_type=self.guess_mime_type(media.media) or "video/mp4",
thumb=self.save_file(media.thumb), thumb=self.save_file(media.thumb),
file=self.save_file(media.media), file=await self.save_file(media.media),
attributes=[ attributes=[
types.DocumentAttributeVideo( types.DocumentAttributeVideo(
supports_streaming=True, supports_streaming=True,
@ -220,13 +220,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 10) media = utils.get_input_media_from_file_id(media.media, media.file_ref, 10)
elif isinstance(media, InputMediaDocument): elif isinstance(media, InputMediaDocument):
if os.path.isfile(media.media): if os.path.isfile(media.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument( media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "application/zip", mime_type=self.guess_mime_type(media.media) or "application/zip",
thumb=self.save_file(media.thumb), thumb=await self.save_file(media.thumb),
file=self.save_file(media.media), file=await self.save_file(media.media),
attributes=[ attributes=[
types.DocumentAttributeFilename( types.DocumentAttributeFilename(
file_name=file_name or os.path.basename(media.media) file_name=file_name or os.path.basename(media.media)
@ -250,19 +250,19 @@ class EditMessageMedia(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5) media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5)
r = self.send( r = await self.send(
functions.messages.EditMessage( functions.messages.EditMessage(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
id=message_id, id=message_id,
media=media, media=media,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats} {i.id: i for i in r.chats}

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class EditMessageReplyMarkup(BaseClient): class EditMessageReplyMarkup(BaseClient):
def edit_message_reply_markup( async def edit_message_reply_markup(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int, message_id: int,
@ -58,9 +58,9 @@ class EditMessageReplyMarkup(BaseClient):
InlineKeyboardMarkup([[ InlineKeyboardMarkup([[
InlineKeyboardButton("New button", callback_data="new_data")]])) InlineKeyboardButton("New button", callback_data="new_data")]]))
""" """
r = self.send( r = await self.send(
functions.messages.EditMessage( functions.messages.EditMessage(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
id=message_id, id=message_id,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
) )
@ -68,7 +68,7 @@ class EditMessageReplyMarkup(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats} {i.id: i for i in r.chats}

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class EditMessageText(BaseClient): class EditMessageText(BaseClient):
def edit_message_text( async def edit_message_text(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int, message_id: int,
@ -75,19 +75,19 @@ class EditMessageText(BaseClient):
disable_web_page_preview=True) disable_web_page_preview=True)
""" """
r = self.send( r = await self.send(
functions.messages.EditMessage( functions.messages.EditMessage(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
id=message_id, id=message_id,
no_webpage=disable_web_page_preview or None, no_webpage=disable_web_page_preview or None,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(text, parse_mode) **await self.parser.parse(text, parse_mode)
) )
) )
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)): if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats} {i.id: i for i in r.chats}

View File

@ -20,11 +20,12 @@ from typing import Union, Iterable, List
import pyrogram import pyrogram
from pyrogram.api import functions, types from pyrogram.api import functions, types
from ...ext import BaseClient from ...ext import BaseClient
class ForwardMessages(BaseClient): class ForwardMessages(BaseClient):
def forward_messages( async def forward_messages(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
from_chat_id: Union[int, str], from_chat_id: Union[int, str],
@ -94,11 +95,11 @@ class ForwardMessages(BaseClient):
forwarded_messages = [] forwarded_messages = []
for chunk in [message_ids[i:i + 200] for i in range(0, len(message_ids), 200)]: for chunk in [message_ids[i:i + 200] for i in range(0, len(message_ids), 200)]:
messages = self.get_messages(chat_id=from_chat_id, message_ids=chunk) messages = await self.get_messages(chat_id=from_chat_id, message_ids=chunk)
for message in messages: for message in messages:
forwarded_messages.append( forwarded_messages.append(
message.forward( await message.forward(
chat_id, chat_id,
disable_notification=disable_notification, disable_notification=disable_notification,
as_copy=True, as_copy=True,
@ -109,10 +110,10 @@ class ForwardMessages(BaseClient):
return pyrogram.List(forwarded_messages) if is_iterable else forwarded_messages[0] return pyrogram.List(forwarded_messages) if is_iterable else forwarded_messages[0]
else: else:
r = self.send( r = await self.send(
functions.messages.ForwardMessages( functions.messages.ForwardMessages(
to_peer=self.resolve_peer(chat_id), to_peer=await self.resolve_peer(chat_id),
from_peer=self.resolve_peer(from_chat_id), from_peer=await self.resolve_peer(from_chat_id),
id=message_ids, id=message_ids,
silent=disable_notification or None, silent=disable_notification or None,
random_id=[self.rnd_id() for _ in message_ids], random_id=[self.rnd_id() for _ in message_ids],
@ -128,7 +129,7 @@ class ForwardMessages(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
forwarded_messages.append( forwarded_messages.append(
pyrogram.Message._parse( await pyrogram.Message._parse(
self, i.message, self, i.message,
users, chats users, chats
) )

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from typing import Union, List from typing import Union, List
@ -28,7 +29,7 @@ log = logging.getLogger(__name__)
class GetHistory(BaseClient): class GetHistory(BaseClient):
def get_history( async def get_history(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
limit: int = 100, limit: int = 100,
@ -83,11 +84,11 @@ class GetHistory(BaseClient):
offset_id = offset_id or (1 if reverse else 0) offset_id = offset_id or (1 if reverse else 0)
messages = utils.parse_messages( messages = await utils.parse_messages(
self, self,
self.send( await self.send(
functions.messages.GetHistory( functions.messages.GetHistory(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
offset_id=offset_id, offset_id=offset_id,
offset_date=offset_date, offset_date=offset_date,
add_offset=offset * (-1 if reverse else 1) - (limit if reverse else 0), add_offset=offset * (-1 if reverse else 1) - (limit if reverse else 0),

View File

@ -26,7 +26,7 @@ log = logging.getLogger(__name__)
class GetHistoryCount(BaseClient): class GetHistoryCount(BaseClient):
def get_history_count( async def get_history_count(
self, self,
chat_id: Union[int, str] chat_id: Union[int, str]
) -> int: ) -> int:
@ -51,9 +51,9 @@ class GetHistoryCount(BaseClient):
app.get_history_count("pyrogramchat") app.get_history_count("pyrogramchat")
""" """
r = self.send( r = await self.send(
functions.messages.GetHistory( functions.messages.GetHistory(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
offset_id=0, offset_id=0,
offset_date=0, offset_date=0,
add_offset=0, add_offset=0,

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from typing import Union, Iterable, List from typing import Union, Iterable, List
@ -30,7 +31,7 @@ log = logging.getLogger(__name__)
class GetMessages(BaseClient): class GetMessages(BaseClient):
def get_messages( async def get_messages(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_ids: Union[int, Iterable[int]] = None, message_ids: Union[int, Iterable[int]] = None,
@ -96,7 +97,7 @@ class GetMessages(BaseClient):
if ids is None: if ids is None:
raise ValueError("No argument supplied. Either pass message_ids or reply_to_message_ids") raise ValueError("No argument supplied. Either pass message_ids or reply_to_message_ids")
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
is_iterable = not isinstance(ids, int) is_iterable = not isinstance(ids, int)
ids = list(ids) if is_iterable else [ids] ids = list(ids) if is_iterable else [ids]
@ -110,8 +111,8 @@ class GetMessages(BaseClient):
else: else:
rpc = functions.messages.GetMessages(id=ids) rpc = functions.messages.GetMessages(id=ids)
r = self.send(rpc) r = await self.send(rpc)
messages = utils.parse_messages(self, r, replies=replies) messages = await utils.parse_messages(self, r, replies=replies)
return messages if is_iterable else messages[0] return messages if is_iterable else messages[0]

View File

@ -16,14 +16,17 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Union, Generator from typing import Union, Optional, Generator
import pyrogram import pyrogram
from async_generator import async_generator, yield_
from ...ext import BaseClient from ...ext import BaseClient
class IterHistory(BaseClient): class IterHistory(BaseClient):
def iter_history( @async_generator
async def iter_history(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
limit: int = 0, limit: int = 0,
@ -31,7 +34,7 @@ class IterHistory(BaseClient):
offset_id: int = 0, offset_id: int = 0,
offset_date: int = 0, offset_date: int = 0,
reverse: bool = False reverse: bool = False
) -> Generator["pyrogram.Message", None, None]: ) -> Optional[Generator["pyrogram.Message", None, None]]:
"""Iterate through a chat history sequentially. """Iterate through a chat history sequentially.
This convenience method does the same as repeatedly calling :meth:`~Client.get_history` in a loop, thus saving This convenience method does the same as repeatedly calling :meth:`~Client.get_history` in a loop, thus saving
@ -76,7 +79,7 @@ class IterHistory(BaseClient):
limit = min(100, total) limit = min(100, total)
while True: while True:
messages = self.get_history( messages = await self.get_history(
chat_id=chat_id, chat_id=chat_id,
limit=limit, limit=limit,
offset=offset, offset=offset,
@ -91,7 +94,7 @@ class IterHistory(BaseClient):
offset_id = messages[-1].message_id + (1 if reverse else 0) offset_id = messages[-1].message_id + (1 if reverse else 0)
for message in messages: for message in messages:
yield message await yield_(message)
current += 1 current += 1

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class ReadHistory(BaseClient): class ReadHistory(BaseClient):
def read_history( async def read_history(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
max_id: int = 0 max_id: int = 0
@ -53,7 +53,7 @@ class ReadHistory(BaseClient):
app.read_history("pyrogramlounge", 123456) app.read_history("pyrogramlounge", 123456)
""" """
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel): if isinstance(peer, types.InputPeerChannel):
q = functions.channels.ReadHistory( q = functions.channels.ReadHistory(
@ -66,6 +66,6 @@ class ReadHistory(BaseClient):
max_id=max_id max_id=max_id
) )
self.send(q) await self.send(q)
return True return True

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class RetractVote(BaseClient): class RetractVote(BaseClient):
def retract_vote( async def retract_vote(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int message_id: int
@ -48,9 +48,9 @@ class RetractVote(BaseClient):
app.retract_vote(chat_id, message_id) app.retract_vote(chat_id, message_id)
""" """
r = self.send( r = await self.send(
functions.messages.SendVote( functions.messages.SendVote(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
msg_id=message_id, msg_id=message_id,
options=[] options=[]
) )

View File

@ -16,7 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Generator from typing import Generator, Optional
from async_generator import async_generator, yield_
import pyrogram import pyrogram
from pyrogram.api import functions, types from pyrogram.api import functions, types
@ -24,11 +26,12 @@ from pyrogram.client.ext import BaseClient, utils
class SearchGlobal(BaseClient): class SearchGlobal(BaseClient):
def search_global( @async_generator
async def search_global(
self, self,
query: str, query: str,
limit: int = 0, limit: int = 0,
) -> Generator["pyrogram.Message", None, None]: ) -> Optional[Generator["pyrogram.Message", None, None]]:
"""Search messages globally from all of your chats. """Search messages globally from all of your chats.
.. note:: .. note::
@ -64,9 +67,9 @@ class SearchGlobal(BaseClient):
offset_id = 0 offset_id = 0
while True: while True:
messages = utils.parse_messages( messages = await utils.parse_messages(
self, self,
self.send( await self.send(
functions.messages.SearchGlobal( functions.messages.SearchGlobal(
q=query, q=query,
offset_rate=offset_date, offset_rate=offset_date,
@ -84,11 +87,11 @@ class SearchGlobal(BaseClient):
last = messages[-1] last = messages[-1]
offset_date = last.date offset_date = last.date
offset_peer = self.resolve_peer(last.chat.id) offset_peer = await self.resolve_peer(last.chat.id)
offset_id = last.message_id offset_id = last.message_id
for message in messages: for message in messages:
yield message await yield_(message)
current += 1 current += 1

View File

@ -16,11 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Union, List, Generator from typing import Union, List, Generator, Optional
import pyrogram import pyrogram
from pyrogram.client.ext import BaseClient, utils from pyrogram.client.ext import BaseClient, utils
from pyrogram.api import functions, types from pyrogram.api import functions, types
from async_generator import async_generator, yield_
class Filters: class Filters:
@ -46,7 +47,7 @@ POSSIBLE_VALUES = list(map(lambda x: x.lower(), filter(lambda x: not x.startswit
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
def get_chunk( async def get_chunk(
client: BaseClient, client: BaseClient,
chat_id: Union[int, str], chat_id: Union[int, str],
query: str = "", query: str = "",
@ -61,9 +62,9 @@ def get_chunk(
raise ValueError('Invalid filter "{}". Possible values are: {}'.format( raise ValueError('Invalid filter "{}". Possible values are: {}'.format(
filter, ", ".join('"{}"'.format(v) for v in POSSIBLE_VALUES))) from None filter, ", ".join('"{}"'.format(v) for v in POSSIBLE_VALUES))) from None
r = client.send( r = await client.send(
functions.messages.Search( functions.messages.Search(
peer=client.resolve_peer(chat_id), peer=await client.resolve_peer(chat_id),
q=query, q=query,
filter=filter, filter=filter,
min_date=0, min_date=0,
@ -74,7 +75,7 @@ def get_chunk(
min_id=0, min_id=0,
max_id=0, max_id=0,
from_id=( from_id=(
client.resolve_peer(from_user) await client.resolve_peer(from_user)
if from_user if from_user
else None else None
), ),
@ -82,12 +83,13 @@ def get_chunk(
) )
) )
return utils.parse_messages(client, r) return await utils.parse_messages(client, r)
class SearchMessages(BaseClient): class SearchMessages(BaseClient):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
def search_messages( @async_generator
async def search_messages(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
query: str = "", query: str = "",
@ -95,7 +97,7 @@ class SearchMessages(BaseClient):
filter: str = "empty", filter: str = "empty",
limit: int = 0, limit: int = 0,
from_user: Union[int, str] = None from_user: Union[int, str] = None
) -> Generator["pyrogram.Message", None, None]: ) -> Optional[Generator["pyrogram.Message", None, None]]:
"""Search for text and media messages inside a specific chat. """Search for text and media messages inside a specific chat.
Parameters: Parameters:
@ -160,7 +162,7 @@ class SearchMessages(BaseClient):
limit = min(100, total) limit = min(100, total)
while True: while True:
messages = get_chunk( messages = await get_chunk(
client=self, client=self,
chat_id=chat_id, chat_id=chat_id,
query=query, query=query,
@ -176,7 +178,7 @@ class SearchMessages(BaseClient):
offset += 100 offset += 100
for message in messages: for message in messages:
yield message await yield_(message)
current += 1 current += 1

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendAnimation(BaseClient): class SendAnimation(BaseClient):
def send_animation( async def send_animation(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
animation: Union[str, BinaryIO], animation: Union[str, BinaryIO],
@ -167,8 +167,8 @@ class SendAnimation(BaseClient):
try: try:
if isinstance(animation, str): if isinstance(animation, str):
if os.path.isfile(animation): if os.path.isfile(animation):
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(animation, progress=progress, progress_args=progress_args) file = await self.save_file(animation, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(animation) or "video/mp4", mime_type=self.guess_mime_type(animation) or "video/mp4",
file=file, file=file,
@ -191,8 +191,8 @@ class SendAnimation(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(animation, file_ref, 10) media = utils.get_input_media_from_file_id(animation, file_ref, 10)
else: else:
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(animation, progress=progress, progress_args=progress_args) file = await self.save_file(animation, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(animation.name) or "video/mp4", mime_type=self.guess_mime_type(animation.name) or "video/mp4",
file=file, file=file,
@ -211,27 +211,27 @@ class SendAnimation(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),
schedule_date=schedule_date, schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(animation, file_id=file.id, file_part=e.x) await self.save_file(animation, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
message = pyrogram.Message._parse( message = await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},
@ -242,7 +242,7 @@ class SendAnimation(BaseClient):
document = message.animation or message.document document = message.animation or message.document
document_id = utils.get_input_media_from_file_id(document.file_id, document.file_ref).id document_id = utils.get_input_media_from_file_id(document.file_id, document.file_ref).id
self.send( await self.send(
functions.messages.SaveGif( functions.messages.SaveGif(
id=document_id, id=document_id,
unsave=True unsave=True

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendAudio(BaseClient): class SendAudio(BaseClient):
def send_audio( async def send_audio(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
audio: Union[str, BinaryIO], audio: Union[str, BinaryIO],
@ -37,8 +37,7 @@ class SendAudio(BaseClient):
duration: int = 0, duration: int = 0,
performer: str = None, performer: str = None,
title: str = None, title: str = None,
thumb: Union[str, BinaryIO] = None, thumb: Union[str, BinaryIO] = None, file_name: str = None,
file_name: str = None,
disable_notification: bool = None, disable_notification: bool = None,
reply_to_message_id: int = None, reply_to_message_id: int = None,
schedule_date: int = None, schedule_date: int = None,
@ -167,8 +166,8 @@ class SendAudio(BaseClient):
try: try:
if isinstance(audio, str): if isinstance(audio, str):
if os.path.isfile(audio): if os.path.isfile(audio):
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(audio, progress=progress, progress_args=progress_args) file = await self.save_file(audio, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(audio) or "audio/mpeg", mime_type=self.guess_mime_type(audio) or "audio/mpeg",
file=file, file=file,
@ -189,8 +188,8 @@ class SendAudio(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(audio, file_ref, 9) media = utils.get_input_media_from_file_id(audio, file_ref, 9)
else: else:
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(audio, progress=progress, progress_args=progress_args) file = await self.save_file(audio, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(audio.name) or "audio/mpeg", mime_type=self.guess_mime_type(audio.name) or "audio/mpeg",
file=file, file=file,
@ -207,27 +206,27 @@ class SendAudio(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),
schedule_date=schedule_date, schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(audio, file_id=file.id, file_part=e.x) await self.save_file(audio, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient, utils
class SendCachedMedia(BaseClient): class SendCachedMedia(BaseClient):
def send_cached_media( async def send_cached_media(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
file_id: str, file_id: str,
@ -94,22 +94,22 @@ class SendCachedMedia(BaseClient):
app.send_cached_media("me", "CAADBAADzg4AAvLQYAEz_x2EOgdRwBYE") app.send_cached_media("me", "CAADBAADzg4AAvLQYAEz_x2EOgdRwBYE")
""" """
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=utils.get_input_media_from_file_id(file_id, file_ref), media=utils.get_input_media_from_file_id(file_id, file_ref),
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),
schedule_date=schedule_date, schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -43,7 +43,7 @@ POSSIBLE_VALUES = list(map(lambda x: x.lower(), filter(lambda x: not x.startswit
class SendChatAction(BaseClient): class SendChatAction(BaseClient):
def send_chat_action(self, chat_id: Union[int, str], action: str) -> bool: async def send_chat_action(self, chat_id: Union[int, str], action: str) -> bool:
"""Tell the other party that something is happening on your side. """Tell the other party that something is happening on your side.
Parameters: Parameters:
@ -93,9 +93,9 @@ class SendChatAction(BaseClient):
else: else:
action = action() action = action()
return self.send( return await self.send(
functions.messages.SetTyping( functions.messages.SetTyping(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
action=action action=action
) )
) )

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendContact(BaseClient): class SendContact(BaseClient):
def send_contact( async def send_contact(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
phone_number: str, phone_number: str,
@ -83,9 +83,9 @@ class SendContact(BaseClient):
app.send_contact("me", "+39 123 456 7890", "Dan") app.send_contact("me", "+39 123 456 7890", "Dan")
""" """
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaContact( media=types.InputMediaContact(
phone_number=phone_number, phone_number=phone_number,
first_name=first_name, first_name=first_name,
@ -103,7 +103,7 @@ class SendContact(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendDice(BaseClient): class SendDice(BaseClient):
def send_dice( async def send_dice(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
emoji: str = "🎲", emoji: str = "🎲",
@ -80,9 +80,9 @@ class SendDice(BaseClient):
app.send_dice("pyrogramlounge", "🏀") app.send_dice("pyrogramlounge", "🏀")
""" """
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaDice(emoticon=emoji), media=types.InputMediaDice(emoticon=emoji),
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
@ -94,11 +94,8 @@ class SendDice(BaseClient):
) )
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
i, return await pyrogram.Message._parse(
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendDocument(BaseClient): class SendDocument(BaseClient):
def send_document( async def send_document(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
document: Union[str, BinaryIO], document: Union[str, BinaryIO],
@ -147,8 +147,8 @@ class SendDocument(BaseClient):
try: try:
if isinstance(document, str): if isinstance(document, str):
if os.path.isfile(document): if os.path.isfile(document):
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(document, progress=progress, progress_args=progress_args) file = await self.save_file(document, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(document) or "application/zip", mime_type=self.guess_mime_type(document) or "application/zip",
file=file, file=file,
@ -165,8 +165,8 @@ class SendDocument(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(document, file_ref, 5) media = utils.get_input_media_from_file_id(document, file_ref, 5)
else: else:
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(document, progress=progress, progress_args=progress_args) file = await self.save_file(document, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(document.name) or "application/zip", mime_type=self.guess_mime_type(document.name) or "application/zip",
file=file, file=file,
@ -178,27 +178,27 @@ class SendDocument(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),
schedule_date=schedule_date, schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(document, file_id=file.id, file_part=e.x) await self.save_file(document, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendLocation(BaseClient): class SendLocation(BaseClient):
def send_location( async def send_location(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
latitude: float, latitude: float,
@ -75,9 +75,9 @@ class SendLocation(BaseClient):
app.send_location("me", 51.500729, -0.124583) app.send_location("me", 51.500729, -0.124583)
""" """
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaGeoPoint( media=types.InputMediaGeoPoint(
geo_point=types.InputGeoPoint( geo_point=types.InputGeoPoint(
lat=latitude, lat=latitude,
@ -95,7 +95,7 @@ class SendLocation(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -30,7 +30,7 @@ log = logging.getLogger(__name__)
class SendMediaGroup(BaseClient): class SendMediaGroup(BaseClient):
# TODO: Add progress parameter # TODO: Add progress parameter
def send_media_group( async def send_media_group(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
media: List[Union["pyrogram.InputMediaPhoto", "pyrogram.InputMediaVideo"]], media: List[Union["pyrogram.InputMediaPhoto", "pyrogram.InputMediaVideo"]],
@ -77,11 +77,11 @@ class SendMediaGroup(BaseClient):
for i in media: for i in media:
if isinstance(i, pyrogram.InputMediaPhoto): if isinstance(i, pyrogram.InputMediaPhoto):
if os.path.isfile(i.media): if os.path.isfile(i.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedPhoto( media=types.InputMediaUploadedPhoto(
file=self.save_file(i.media) file=await self.save_file(i.media)
) )
) )
) )
@ -94,9 +94,9 @@ class SendMediaGroup(BaseClient):
) )
) )
elif re.match("^https?://", i.media): elif re.match("^https?://", i.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaPhotoExternal( media=types.InputMediaPhotoExternal(
url=i.media url=i.media
) )
@ -114,11 +114,11 @@ class SendMediaGroup(BaseClient):
media = utils.get_input_media_from_file_id(i.media, i.file_ref, 2) media = utils.get_input_media_from_file_id(i.media, i.file_ref, 2)
elif isinstance(i, pyrogram.InputMediaVideo): elif isinstance(i, pyrogram.InputMediaVideo):
if os.path.isfile(i.media): if os.path.isfile(i.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument( media=types.InputMediaUploadedDocument(
file=self.save_file(i.media), file=await self.save_file(i.media),
thumb=self.save_file(i.thumb), thumb=self.save_file(i.thumb),
mime_type=self.guess_mime_type(i.media) or "video/mp4", mime_type=self.guess_mime_type(i.media) or "video/mp4",
attributes=[ attributes=[
@ -142,9 +142,9 @@ class SendMediaGroup(BaseClient):
) )
) )
elif re.match("^https?://", i.media): elif re.match("^https?://", i.media):
media = self.send( media = await self.send(
functions.messages.UploadMedia( functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaDocumentExternal( media=types.InputMediaDocumentExternal(
url=i.media url=i.media
) )
@ -165,20 +165,20 @@ class SendMediaGroup(BaseClient):
types.InputSingleMedia( types.InputSingleMedia(
media=media, media=media,
random_id=self.rnd_id(), random_id=self.rnd_id(),
**self.parser.parse(i.caption, i.parse_mode) **await self.parser.parse(i.caption, i.parse_mode)
) )
) )
r = self.send( r = await self.send(
functions.messages.SendMultiMedia( functions.messages.SendMultiMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
multi_media=multi_media, multi_media=multi_media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id reply_to_msg_id=reply_to_message_id
) )
) )
return utils.parse_messages( return await utils.parse_messages(
self, self,
types.messages.Messages( types.messages.Messages(
messages=[m.message for m in filter( messages=[m.message for m in filter(

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class SendMessage(BaseClient): class SendMessage(BaseClient):
def send_message( async def send_message(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
text: str, text: str,
@ -116,11 +116,11 @@ class SendMessage(BaseClient):
])) ]))
""" """
message, entities = self.parser.parse(text, parse_mode).values() message, entities = (await self.parser.parse(text, parse_mode)).values()
r = self.send( r = await self.send(
functions.messages.SendMessage( functions.messages.SendMessage(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
no_webpage=disable_web_page_preview or None, no_webpage=disable_web_page_preview or None,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
@ -133,7 +133,7 @@ class SendMessage(BaseClient):
) )
if isinstance(r, types.UpdateShortSentMessage): if isinstance(r, types.UpdateShortSentMessage):
peer = self.resolve_peer(chat_id) peer = await self.resolve_peer(chat_id)
peer_id = ( peer_id = (
peer.user_id peer.user_id
@ -160,7 +160,7 @@ class SendMessage(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendPhoto(BaseClient): class SendPhoto(BaseClient):
def send_photo( async def send_photo(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
photo: Union[str, BinaryIO], photo: Union[str, BinaryIO],
@ -141,7 +141,7 @@ class SendPhoto(BaseClient):
try: try:
if isinstance(photo, str): if isinstance(photo, str):
if os.path.isfile(photo): if os.path.isfile(photo):
file = self.save_file(photo, progress=progress, progress_args=progress_args) file = await self.save_file(photo, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedPhoto( media = types.InputMediaUploadedPhoto(
file=file, file=file,
ttl_seconds=ttl_seconds ttl_seconds=ttl_seconds
@ -154,7 +154,7 @@ class SendPhoto(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(photo, file_ref, 2) media = utils.get_input_media_from_file_id(photo, file_ref, 2)
else: else:
file = self.save_file(photo, progress=progress, progress_args=progress_args) file = await self.save_file(photo, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedPhoto( media = types.InputMediaUploadedPhoto(
file=file, file=file,
ttl_seconds=ttl_seconds ttl_seconds=ttl_seconds
@ -162,27 +162,27 @@ class SendPhoto(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),
schedule_date=schedule_date, schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(photo, file_id=file.id, file_part=e.x) await self.save_file(photo, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendPoll(BaseClient): class SendPoll(BaseClient):
def send_poll( async def send_poll(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
question: str, question: str,
@ -95,9 +95,9 @@ class SendPoll(BaseClient):
app.send_poll(chat_id, "Is this a poll question?", ["Yes", "No", "Maybe"]) app.send_poll(chat_id, "Is this a poll question?", ["Yes", "No", "Maybe"])
""" """
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaPoll( media=types.InputMediaPoll(
poll=types.Poll( poll=types.Poll(
id=0, id=0,
@ -123,7 +123,7 @@ class SendPoll(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendSticker(BaseClient): class SendSticker(BaseClient):
def send_sticker( async def send_sticker(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
sticker: Union[str, BinaryIO], sticker: Union[str, BinaryIO],
@ -117,7 +117,7 @@ class SendSticker(BaseClient):
try: try:
if isinstance(sticker, str): if isinstance(sticker, str):
if os.path.isfile(sticker): if os.path.isfile(sticker):
file = self.save_file(sticker, progress=progress, progress_args=progress_args) file = await self.save_file(sticker, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(sticker) or "image/webp", mime_type=self.guess_mime_type(sticker) or "image/webp",
file=file, file=file,
@ -132,7 +132,7 @@ class SendSticker(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(sticker, file_ref, 8) media = utils.get_input_media_from_file_id(sticker, file_ref, 8)
else: else:
file = self.save_file(sticker, progress=progress, progress_args=progress_args) file = await self.save_file(sticker, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(sticker.name) or "image/webp", mime_type=self.guess_mime_type(sticker.name) or "image/webp",
file=file, file=file,
@ -143,9 +143,9 @@ class SendSticker(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
@ -156,14 +156,14 @@ class SendSticker(BaseClient):
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(sticker, file_id=file.id, file_part=e.x) await self.save_file(sticker, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendVenue(BaseClient): class SendVenue(BaseClient):
def send_venue( async def send_venue(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
latitude: float, latitude: float,
@ -94,9 +94,9 @@ class SendVenue(BaseClient):
"me", 51.500729, -0.124583, "me", 51.500729, -0.124583,
"Elizabeth Tower", "Westminster, London SW1A 0AA, UK") "Elizabeth Tower", "Westminster, London SW1A 0AA, UK")
""" """
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=types.InputMediaVenue( media=types.InputMediaVenue(
geo_point=types.InputGeoPoint( geo_point=types.InputGeoPoint(
lat=latitude, lat=latitude,
@ -119,7 +119,7 @@ class SendVenue(BaseClient):
for i in r.updates: for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)): if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendVideo(BaseClient): class SendVideo(BaseClient):
def send_video( async def send_video(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
video: Union[str, BinaryIO], video: Union[str, BinaryIO],
@ -164,8 +164,8 @@ class SendVideo(BaseClient):
try: try:
if isinstance(video, str): if isinstance(video, str):
if os.path.isfile(video): if os.path.isfile(video):
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(video, progress=progress, progress_args=progress_args) file = await self.save_file(video, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video) or "video/mp4", mime_type=self.guess_mime_type(video) or "video/mp4",
file=file, file=file,
@ -187,8 +187,8 @@ class SendVideo(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(video, file_ref, 4) media = utils.get_input_media_from_file_id(video, file_ref, 4)
else: else:
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(video, progress=progress, progress_args=progress_args) file = await self.save_file(video, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video.name) or "video/mp4", mime_type=self.guess_mime_type(video.name) or "video/mp4",
file=file, file=file,
@ -206,27 +206,27 @@ class SendVideo(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),
schedule_date=schedule_date, schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(video, file_id=file.id, file_part=e.x) await self.save_file(video, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -26,7 +26,7 @@ from pyrogram.errors import FilePartMissing
class SendVideoNote(BaseClient): class SendVideoNote(BaseClient):
def send_video_note( async def send_video_note(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
video_note: Union[str, BinaryIO], video_note: Union[str, BinaryIO],
@ -131,8 +131,8 @@ class SendVideoNote(BaseClient):
try: try:
if isinstance(video_note, str): if isinstance(video_note, str):
if os.path.isfile(video_note): if os.path.isfile(video_note):
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(video_note, progress=progress, progress_args=progress_args) file = await self.save_file(video_note, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video_note) or "video/mp4", mime_type=self.guess_mime_type(video_note) or "video/mp4",
file=file, file=file,
@ -149,8 +149,8 @@ class SendVideoNote(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(video_note, file_ref, 13) media = utils.get_input_media_from_file_id(video_note, file_ref, 13)
else: else:
thumb = self.save_file(thumb) thumb = await self.save_file(thumb)
file = self.save_file(video_note, progress=progress, progress_args=progress_args) file = await self.save_file(video_note, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video_note.name) or "video/mp4", mime_type=self.guess_mime_type(video_note.name) or "video/mp4",
file=file, file=file,
@ -167,9 +167,9 @@ class SendVideoNote(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
@ -180,14 +180,14 @@ class SendVideoNote(BaseClient):
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(video_note, file_id=file.id, file_part=e.x) await self.save_file(video_note, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendVoice(BaseClient): class SendVoice(BaseClient):
def send_voice( async def send_voice(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
voice: Union[str, BinaryIO], voice: Union[str, BinaryIO],
@ -136,7 +136,7 @@ class SendVoice(BaseClient):
try: try:
if isinstance(voice, str): if isinstance(voice, str):
if os.path.isfile(voice): if os.path.isfile(voice):
file = self.save_file(voice, progress=progress, progress_args=progress_args) file = await self.save_file(voice, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(voice) or "audio/mpeg", mime_type=self.guess_mime_type(voice) or "audio/mpeg",
file=file, file=file,
@ -154,7 +154,7 @@ class SendVoice(BaseClient):
else: else:
media = utils.get_input_media_from_file_id(voice, file_ref, 3) media = utils.get_input_media_from_file_id(voice, file_ref, 3)
else: else:
file = self.save_file(voice, progress=progress, progress_args=progress_args) file = await self.save_file(voice, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument( media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(voice.name) or "audio/mpeg", mime_type=self.guess_mime_type(voice.name) or "audio/mpeg",
file=file, file=file,
@ -168,27 +168,27 @@ class SendVoice(BaseClient):
while True: while True:
try: try:
r = self.send( r = await self.send(
functions.messages.SendMedia( functions.messages.SendMedia(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
media=media, media=media,
silent=disable_notification or None, silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id, reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(), random_id=self.rnd_id(),
schedule_date=schedule_date, schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None, reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode) **await self.parser.parse(caption, parse_mode)
) )
) )
except FilePartMissing as e: except FilePartMissing as e:
self.save_file(voice, file_id=file.id, file_part=e.x) await self.save_file(voice, file_id=file.id, file_part=e.x)
else: else:
for i in r.updates: for i in r.updates:
if isinstance( if isinstance(
i, i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage) (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
): ):
return pyrogram.Message._parse( return await pyrogram.Message._parse(
self, i.message, self, i.message,
{i.id: i for i in r.users}, {i.id: i for i in r.users},
{i.id: i for i in r.chats}, {i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class StopPoll(BaseClient): class StopPoll(BaseClient):
def stop_poll( async def stop_poll(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: int, message_id: int,
@ -54,11 +54,11 @@ class StopPoll(BaseClient):
app.stop_poll(chat_id, message_id) app.stop_poll(chat_id, message_id)
""" """
poll = self.get_messages(chat_id, message_id).poll poll = (await self.get_messages(chat_id, message_id)).poll
r = self.send( r = await self.send(
functions.messages.EditMessage( functions.messages.EditMessage(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
id=message_id, id=message_id,
media=types.InputMediaPoll( media=types.InputMediaPoll(
poll=types.Poll( poll=types.Poll(

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class VotePoll(BaseClient): class VotePoll(BaseClient):
def vote_poll( async def vote_poll(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
message_id: id, message_id: id,
@ -53,12 +53,12 @@ class VotePoll(BaseClient):
app.vote_poll(chat_id, message_id, 6) app.vote_poll(chat_id, message_id, 6)
""" """
poll = self.get_messages(chat_id, message_id).poll poll = (await self.get_messages(chat_id, message_id)).poll
options = [options] if not isinstance(options, list) else options options = [options] if not isinstance(options, list) else options
r = self.send( r = await self.send(
functions.messages.SendVote( functions.messages.SendVote(
peer=self.resolve_peer(chat_id), peer=await self.resolve_peer(chat_id),
msg_id=message_id, msg_id=message_id,
options=[poll.options[option].data for option in options] options=[poll.options[option].data for option in options]
) )

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class ChangeCloudPassword(BaseClient): class ChangeCloudPassword(BaseClient):
def change_cloud_password( async def change_cloud_password(
self, self,
current_password: str, current_password: str,
new_password: str, new_password: str,
@ -57,7 +57,7 @@ class ChangeCloudPassword(BaseClient):
# Change password and hint # Change password and hint
app.change_cloud_password("current_password", "new_password", new_hint="hint") app.change_cloud_password("current_password", "new_password", new_hint="hint")
""" """
r = self.send(functions.account.GetPassword()) r = await self.send(functions.account.GetPassword())
if not r.has_password: if not r.has_password:
raise ValueError("There is no cloud password to change") raise ValueError("There is no cloud password to change")
@ -66,7 +66,7 @@ class ChangeCloudPassword(BaseClient):
new_hash = btoi(compute_hash(r.new_algo, new_password)) new_hash = btoi(compute_hash(r.new_algo, new_password))
new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p))) new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p)))
self.send( await self.send(
functions.account.UpdatePasswordSettings( functions.account.UpdatePasswordSettings(
password=compute_check(r, current_password), password=compute_check(r, current_password),
new_settings=types.account.PasswordInputSettings( new_settings=types.account.PasswordInputSettings(

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class EnableCloudPassword(BaseClient): class EnableCloudPassword(BaseClient):
def enable_cloud_password( async def enable_cloud_password(
self, self,
password: str, password: str,
hint: str = "", hint: str = "",
@ -62,7 +62,7 @@ class EnableCloudPassword(BaseClient):
# Enable password with hint and email # Enable password with hint and email
app.enable_cloud_password("password", hint="hint", email="user@email.com") app.enable_cloud_password("password", hint="hint", email="user@email.com")
""" """
r = self.send(functions.account.GetPassword()) r = await self.send(functions.account.GetPassword())
if r.has_password: if r.has_password:
raise ValueError("There is already a cloud password enabled") raise ValueError("There is already a cloud password enabled")
@ -71,7 +71,7 @@ class EnableCloudPassword(BaseClient):
new_hash = btoi(compute_hash(r.new_algo, password)) new_hash = btoi(compute_hash(r.new_algo, password))
new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p))) new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p)))
self.send( await self.send(
functions.account.UpdatePasswordSettings( functions.account.UpdatePasswordSettings(
password=types.InputCheckPasswordEmpty(), password=types.InputCheckPasswordEmpty(),
new_settings=types.account.PasswordInputSettings( new_settings=types.account.PasswordInputSettings(

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class RemoveCloudPassword(BaseClient): class RemoveCloudPassword(BaseClient):
def remove_cloud_password( async def remove_cloud_password(
self, self,
password: str password: str
) -> bool: ) -> bool:
@ -43,12 +43,12 @@ class RemoveCloudPassword(BaseClient):
app.remove_cloud_password("password") app.remove_cloud_password("password")
""" """
r = self.send(functions.account.GetPassword()) r = await self.send(functions.account.GetPassword())
if not r.has_password: if not r.has_password:
raise ValueError("There is no cloud password to remove") raise ValueError("There is no cloud password to remove")
self.send( await self.send(
functions.account.UpdatePasswordSettings( functions.account.UpdatePasswordSettings(
password=compute_check(r, password), password=compute_check(r, password),
new_settings=types.account.PasswordInputSettings( new_settings=types.account.PasswordInputSettings(

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class BlockUser(BaseClient): class BlockUser(BaseClient):
def block_user( async def block_user(
self, self,
user_id: Union[int, str] user_id: Union[int, str]
) -> bool: ) -> bool:
@ -44,9 +44,9 @@ class BlockUser(BaseClient):
app.block_user(user_id) app.block_user(user_id)
""" """
return bool( return bool(
self.send( await self.send(
functions.contacts.Block( functions.contacts.Block(
id=self.resolve_peer(user_id) id=await self.resolve_peer(user_id)
) )
) )
) )

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class DeleteProfilePhotos(BaseClient): class DeleteProfilePhotos(BaseClient):
def delete_profile_photos( async def delete_profile_photos(
self, self,
photo_ids: Union[str, List[str]] photo_ids: Union[str, List[str]]
) -> bool: ) -> bool:
@ -53,7 +53,7 @@ class DeleteProfilePhotos(BaseClient):
photo_ids = photo_ids if isinstance(photo_ids, list) else [photo_ids] photo_ids = photo_ids if isinstance(photo_ids, list) else [photo_ids]
input_photos = [utils.get_input_media_from_file_id(i).id for i in photo_ids] input_photos = [utils.get_input_media_from_file_id(i).id for i in photo_ids]
return bool(self.send( return bool(await self.send(
functions.photos.DeletePhotos( functions.photos.DeletePhotos(
id=input_photos id=input_photos
) )

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class GetCommonChats(BaseClient): class GetCommonChats(BaseClient):
def get_common_chats(self, user_id: Union[int, str]) -> list: async def get_common_chats(self, user_id: Union[int, str]) -> list:
"""Get the common chats you have with a user. """Get the common chats you have with a user.
Parameters: Parameters:
@ -46,10 +46,10 @@ class GetCommonChats(BaseClient):
print(common) print(common)
""" """
peer = self.resolve_peer(user_id) peer = await self.resolve_peer(user_id)
if isinstance(peer, types.InputPeerUser): if isinstance(peer, types.InputPeerUser):
r = self.send( r = await self.send(
functions.messages.GetCommonChats( functions.messages.GetCommonChats(
user_id=peer, user_id=peer,
max_id=0, max_id=0,

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class GetMe(BaseClient): class GetMe(BaseClient):
def get_me(self) -> "pyrogram.User": async def get_me(self) -> "pyrogram.User":
"""Get your own user identity. """Get your own user identity.
Returns: Returns:
@ -36,9 +36,9 @@ class GetMe(BaseClient):
""" """
return pyrogram.User._parse( return pyrogram.User._parse(
self, self,
self.send( (await self.send(
functions.users.GetFullUser( functions.users.GetFullUser(
id=types.InputPeerSelf() id=types.InputPeerSelf()
) )
).user )).user
) )

View File

@ -25,7 +25,7 @@ from ...ext import BaseClient
class GetProfilePhotos(BaseClient): class GetProfilePhotos(BaseClient):
def get_profile_photos( async def get_profile_photos(
self, self,
chat_id: Union[int, str], chat_id: Union[int, str],
offset: int = 0, offset: int = 0,
@ -62,12 +62,12 @@ class GetProfilePhotos(BaseClient):
# Get 3 profile photos of a user, skip the first 5 # Get 3 profile photos of a user, skip the first 5
app.get_profile_photos("haskell", limit=3, offset=5) app.get_profile_photos("haskell", limit=3, offset=5)
""" """
peer_id = self.resolve_peer(chat_id) peer_id = await self.resolve_peer(chat_id)
if isinstance(peer_id, types.InputPeerChannel): if isinstance(peer_id, types.InputPeerChannel):
r = utils.parse_messages( r = await utils.parse_messages(
self, self,
self.send( await self.send(
functions.messages.Search( functions.messages.Search(
peer=peer_id, peer=peer_id,
q="", q="",
@ -86,7 +86,7 @@ class GetProfilePhotos(BaseClient):
return pyrogram.List([message.new_chat_photo for message in r][:limit]) return pyrogram.List([message.new_chat_photo for message in r][:limit])
else: else:
r = self.send( r = await self.send(
functions.photos.GetUserPhotos( functions.photos.GetUserPhotos(
user_id=peer_id, user_id=peer_id,
offset=offset, offset=offset,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class GetProfilePhotosCount(BaseClient): class GetProfilePhotosCount(BaseClient):
def get_profile_photos_count(self, chat_id: Union[int, str]) -> int: async def get_profile_photos_count(self, chat_id: Union[int, str]) -> int:
"""Get the total count of profile pictures for a user. """Get the total count of profile pictures for a user.
Parameters: Parameters:
@ -42,10 +42,10 @@ class GetProfilePhotosCount(BaseClient):
print(count) print(count)
""" """
peer_id = self.resolve_peer(chat_id) peer_id = await self.resolve_peer(chat_id)
if isinstance(peer_id, types.InputPeerChannel): if isinstance(peer_id, types.InputPeerChannel):
r = self.send( r = await self.send(
functions.messages.GetSearchCounters( functions.messages.GetSearchCounters(
peer=peer_id, peer=peer_id,
filters=[types.InputMessagesFilterChatPhotos()], filters=[types.InputMessagesFilterChatPhotos()],
@ -54,7 +54,7 @@ class GetProfilePhotosCount(BaseClient):
return r[0].count return r[0].count
else: else:
r = self.send( r = await self.send(
functions.photos.GetUserPhotos( functions.photos.GetUserPhotos(
user_id=peer_id, user_id=peer_id,
offset=0, offset=0,

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from typing import Iterable, Union, List from typing import Iterable, Union, List
import pyrogram import pyrogram
@ -24,7 +25,7 @@ from ...ext import BaseClient
class GetUsers(BaseClient): class GetUsers(BaseClient):
def get_users( async def get_users(
self, self,
user_ids: Union[Iterable[Union[int, str]], int, str] user_ids: Union[Iterable[Union[int, str]], int, str]
) -> Union["pyrogram.User", List["pyrogram.User"]]: ) -> Union["pyrogram.User", List["pyrogram.User"]]:
@ -53,9 +54,9 @@ class GetUsers(BaseClient):
""" """
is_iterable = not isinstance(user_ids, (int, str)) is_iterable = not isinstance(user_ids, (int, str))
user_ids = list(user_ids) if is_iterable else [user_ids] user_ids = list(user_ids) if is_iterable else [user_ids]
user_ids = [self.resolve_peer(i) for i in user_ids] user_ids = await asyncio.gather(*[self.resolve_peer(i) for i in user_ids])
r = self.send( r = await self.send(
functions.users.GetUsers( functions.users.GetUsers(
id=user_ids id=user_ids
) )

Some files were not shown because too many files have changed in this diff Show More