Change logging hierarchy for loading plugins (#451)

Loading plugins shouldn't be considered a warning
This commit is contained in:
CyanBook 2020-08-21 07:28:27 +02:00 committed by GitHub
parent 2e08266f56
commit c8c6faa96e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
133 changed files with 1349 additions and 1207 deletions

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
__version__ = "0.18.0"
__version__ = "0.18.0-async"
__license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)"
__copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>"

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import io
import logging
import math
@ -23,14 +24,11 @@ import os
import re
import shutil
import tempfile
import threading
import time
from configparser import ConfigParser
from hashlib import sha256, md5
from importlib import import_module
from pathlib import Path
from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Thread
from typing import Union, List, BinaryIO
from pyrogram.api import functions, types
@ -40,11 +38,13 @@ from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check
from pyrogram.crypto import AES
from pyrogram.errors import (
PhoneMigrate, NetworkMigrate, SessionPasswordNeeded, PeerIdInvalid, VolumeLocNotFound, UserMigrate, ChannelPrivate,
PhoneMigrate, NetworkMigrate, SessionPasswordNeeded,
PeerIdInvalid, VolumeLocNotFound, UserMigrate, ChannelPrivate,
AuthBytesInvalid, BadRequest
)
from pyrogram.session import Auth, Session
from .ext import utils, Syncer, BaseClient, Dispatcher
from .ext.utils import ainput
from .methods import Methods
from .storage import Storage, FileStorage, MemoryStorage
from .types import User, SentCode, TermsOfService
@ -127,7 +127,7 @@ class Client(Methods, BaseClient):
Defaults to False.
workers (``int``, *optional*):
Thread pool size for handling incoming updates.
Number of maximum concurrent workers for handling incoming updates.
Defaults to 4.
workdir (``str``, *optional*):
@ -243,6 +243,12 @@ class Client(Methods, BaseClient):
except ConnectionError:
pass
async def __aenter__(self):
return await self.start()
async def __aexit__(self, *args):
await self.stop()
@property
def proxy(self):
return self._proxy
@ -259,7 +265,7 @@ class Client(Methods, BaseClient):
self._proxy["enabled"] = bool(value.get("enabled", True))
self._proxy.update(value)
def connect(self) -> bool:
async def connect(self) -> bool:
"""
Connect the client to Telegram servers.
@ -274,16 +280,17 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client is already connected")
self.load_config()
self.load_session()
await self.load_session()
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
self.session.start()
await self.session.start()
self.is_connected = True
return bool(self.storage.user_id())
def disconnect(self):
async def disconnect(self):
"""Disconnect the client from Telegram servers.
Raises:
@ -296,11 +303,11 @@ class Client(Methods, BaseClient):
if self.is_initialized:
raise ConnectionError("Can't disconnect an initialized client")
self.session.stop()
await self.session.stop()
self.storage.close()
self.is_connected = False
def initialize(self):
async def initialize(self):
"""Initialize the client by starting up workers.
This method will start updates and download workers.
@ -319,33 +326,26 @@ class Client(Methods, BaseClient):
self.load_plugins()
if not self.no_updates:
for i in range(self.UPDATES_WORKERS):
self.updates_workers_list.append(
Thread(
target=self.updates_worker,
name="UpdatesWorker#{}".format(i + 1)
)
for _ in range(Client.UPDATES_WORKERS):
self.updates_worker_tasks.append(
asyncio.ensure_future(self.updates_worker())
)
self.updates_workers_list[-1].start()
logging.info("Started {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS))
for i in range(self.DOWNLOAD_WORKERS):
self.download_workers_list.append(
Thread(
target=self.download_worker,
name="DownloadWorker#{}".format(i + 1)
)
for _ in range(Client.DOWNLOAD_WORKERS):
self.download_worker_tasks.append(
asyncio.ensure_future(self.download_worker())
)
self.download_workers_list[-1].start()
logging.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
self.dispatcher.start()
Syncer.add(self)
await self.dispatcher.start()
await Syncer.add(self)
self.is_initialized = True
def terminate(self):
async def terminate(self):
"""Terminate the client by shutting down workers.
This method does the opposite of :meth:`~Client.initialize`.
@ -358,37 +358,41 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client is already terminated")
if self.takeout_id:
self.send(functions.account.FinishTakeoutSession())
await self.send(functions.account.FinishTakeoutSession())
log.warning("Takeout session {} finished".format(self.takeout_id))
Syncer.remove(self)
self.dispatcher.stop()
await Syncer.remove(self)
await self.dispatcher.stop()
for _ in range(self.DOWNLOAD_WORKERS):
self.download_queue.put(None)
for _ in range(Client.DOWNLOAD_WORKERS):
self.download_queue.put_nowait(None)
for i in self.download_workers_list:
i.join()
for task in self.download_worker_tasks:
await task
self.download_workers_list.clear()
self.download_worker_tasks.clear()
logging.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
if not self.no_updates:
for _ in range(self.UPDATES_WORKERS):
self.updates_queue.put(None)
for _ in range(Client.UPDATES_WORKERS):
self.updates_queue.put_nowait(None)
for i in self.updates_workers_list:
i.join()
for task in self.updates_worker_tasks:
await task
self.updates_workers_list.clear()
self.updates_worker_tasks.clear()
for i in self.media_sessions.values():
i.stop()
logging.info("Stopped {} UpdatesWorkerTasks".format(Client.UPDATES_WORKERS))
for media_session in self.media_sessions.values():
await media_session.stop()
self.media_sessions.clear()
self.is_initialized = False
def send_code(self, phone_number: str) -> SentCode:
async def send_code(self, phone_number: str) -> SentCode:
"""Send the confirmation code to the given phone number.
Parameters:
@ -405,7 +409,7 @@ class Client(Methods, BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.auth.SendCode(
phone_number=phone_number,
api_id=self.api_id,
@ -414,17 +418,17 @@ class Client(Methods, BaseClient):
)
)
except (PhoneMigrate, NetworkMigrate) as e:
self.session.stop()
await self.session.stop()
self.storage.dc_id(e.x)
self.storage.auth_key(Auth(self, self.storage.dc_id()).create())
self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
self.session.start()
await self.session.start()
else:
return SentCode._parse(r)
def resend_code(self, phone_number: str, phone_code_hash: str) -> SentCode:
async def resend_code(self, phone_number: str, phone_code_hash: str) -> SentCode:
"""Re-send the confirmation code using a different type.
The type of the code to be re-sent is specified in the *next_type* attribute of the :obj:`SentCode` object
@ -445,7 +449,7 @@ class Client(Methods, BaseClient):
"""
phone_number = phone_number.strip(" +")
r = self.send(
r = await self.send(
functions.auth.ResendCode(
phone_number=phone_number,
phone_code_hash=phone_code_hash
@ -454,7 +458,8 @@ class Client(Methods, BaseClient):
return SentCode._parse(r)
def sign_in(self, phone_number: str, phone_code_hash: str, phone_code: str) -> Union[User, TermsOfService, bool]:
async def sign_in(self, phone_number: str, phone_code_hash: str, phone_code: str) -> Union[
User, TermsOfService, bool]:
"""Authorize a user in Telegram with a valid confirmation code.
Parameters:
@ -479,7 +484,7 @@ class Client(Methods, BaseClient):
"""
phone_number = phone_number.strip(" +")
r = self.send(
r = await self.send(
functions.auth.SignIn(
phone_number=phone_number,
phone_code_hash=phone_code_hash,
@ -498,7 +503,7 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user)
def sign_up(self, phone_number: str, phone_code_hash: str, first_name: str, last_name: str = "") -> User:
async def sign_up(self, phone_number: str, phone_code_hash: str, first_name: str, last_name: str = "") -> User:
"""Register a new user in Telegram.
Parameters:
@ -522,7 +527,7 @@ class Client(Methods, BaseClient):
"""
phone_number = phone_number.strip(" +")
r = self.send(
r = await self.send(
functions.auth.SignUp(
phone_number=phone_number,
first_name=first_name,
@ -536,7 +541,7 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user)
def sign_in_bot(self, bot_token: str) -> User:
async def sign_in_bot(self, bot_token: str) -> User:
"""Authorize a bot using its bot token generated by BotFather.
Parameters:
@ -551,7 +556,7 @@ class Client(Methods, BaseClient):
"""
while True:
try:
r = self.send(
r = await self.send(
functions.auth.ImportBotAuthorization(
flags=0,
api_id=self.api_id,
@ -560,28 +565,28 @@ class Client(Methods, BaseClient):
)
)
except UserMigrate as e:
self.session.stop()
await self.session.stop()
self.storage.dc_id(e.x)
self.storage.auth_key(Auth(self, self.storage.dc_id()).create())
self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
self.session.start()
await self.session.start()
else:
self.storage.user_id(r.user.id)
self.storage.is_bot(True)
return User._parse(self, r.user)
def get_password_hint(self) -> str:
async def get_password_hint(self) -> str:
"""Get your Two-Step Verification password hint.
Returns:
``str``: On success, the password hint as string is returned.
"""
return self.send(functions.account.GetPassword()).hint
return (await self.send(functions.account.GetPassword())).hint
def check_password(self, password: str) -> User:
async def check_password(self, password: str) -> User:
"""Check your Two-Step Verification password and log in.
Parameters:
@ -594,10 +599,10 @@ class Client(Methods, BaseClient):
Raises:
BadRequest: In case the password is invalid.
"""
r = self.send(
r = await self.send(
functions.auth.CheckPassword(
password=compute_check(
self.send(functions.account.GetPassword()),
await self.send(functions.account.GetPassword()),
password
)
)
@ -608,7 +613,7 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user)
def send_recovery_code(self) -> str:
async def send_recovery_code(self) -> str:
"""Send a code to your email to recover your password.
Returns:
@ -617,11 +622,11 @@ class Client(Methods, BaseClient):
Raises:
BadRequest: In case no recovery email was set up.
"""
return self.send(
return (await self.send(
functions.auth.RequestPasswordRecovery()
).email_pattern
)).email_pattern
def recover_password(self, recovery_code: str) -> User:
async def recover_password(self, recovery_code: str) -> User:
"""Recover your password with a recovery code and log in.
Parameters:
@ -634,7 +639,7 @@ class Client(Methods, BaseClient):
Raises:
BadRequest: In case the recovery code is invalid.
"""
r = self.send(
r = await self.send(
functions.auth.RecoverPassword(
code=recovery_code
)
@ -645,14 +650,14 @@ class Client(Methods, BaseClient):
return User._parse(self, r.user)
def accept_terms_of_service(self, terms_of_service_id: str) -> bool:
async def accept_terms_of_service(self, terms_of_service_id: str) -> bool:
"""Accept the given terms of service.
Parameters:
terms_of_service_id (``str``):
The terms of service identifier.
"""
r = self.send(
r = await self.send(
functions.help.AcceptTermsOfService(
id=types.DataJSON(
data=terms_of_service_id
@ -664,15 +669,15 @@ class Client(Methods, BaseClient):
return True
def authorize(self) -> User:
async def authorize(self) -> User:
if self.bot_token:
return self.sign_in_bot(self.bot_token)
return await self.sign_in_bot(self.bot_token)
while True:
try:
if not self.phone_number:
while True:
value = input("Enter phone number or bot token: ")
value = await ainput("Enter phone number or bot token: ")
if not value:
continue
@ -684,11 +689,11 @@ class Client(Methods, BaseClient):
if ":" in value:
self.bot_token = value
return self.sign_in_bot(value)
return await self.sign_in_bot(value)
else:
self.phone_number = value
sent_code = self.send_code(self.phone_number)
sent_code = await self.send_code(self.phone_number)
except BadRequest as e:
print(e.MESSAGE)
self.phone_number = None
@ -697,7 +702,7 @@ class Client(Methods, BaseClient):
break
if self.force_sms:
sent_code = self.resend_code(self.phone_number, sent_code.phone_code_hash)
sent_code = await self.resend_code(self.phone_number, sent_code.phone_code_hash)
print("The confirmation code has been sent via {}".format(
{
@ -710,10 +715,10 @@ class Client(Methods, BaseClient):
while True:
if not self.phone_code:
self.phone_code = input("Enter confirmation code: ")
self.phone_code = await ainput("Enter confirmation code: ")
try:
signed_in = self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code)
signed_in = await self.sign_in(self.phone_number, sent_code.phone_code_hash, self.phone_code)
except BadRequest as e:
print(e.MESSAGE)
self.phone_code = None
@ -721,24 +726,24 @@ class Client(Methods, BaseClient):
print(e.MESSAGE)
while True:
print("Password hint: {}".format(self.get_password_hint()))
print("Password hint: {}".format(await self.get_password_hint()))
if not self.password:
self.password = input("Enter password (empty to recover): ")
self.password = await ainput("Enter password (empty to recover): ")
try:
if not self.password:
confirm = input("Confirm password recovery (y/n): ")
confirm = await ainput("Confirm password recovery (y/n): ")
if confirm == "y":
email_pattern = self.send_recovery_code()
email_pattern = await self.send_recovery_code()
print("The recovery code has been sent to {}".format(email_pattern))
while True:
recovery_code = input("Enter recovery code: ")
recovery_code = await ainput("Enter recovery code: ")
try:
return self.recover_password(recovery_code)
return await self.recover_password(recovery_code)
except BadRequest as e:
print(e.MESSAGE)
except Exception as e:
@ -747,7 +752,7 @@ class Client(Methods, BaseClient):
else:
self.password = None
else:
return self.check_password(self.password)
return await self.check_password(self.password)
except BadRequest as e:
print(e.MESSAGE)
self.password = None
@ -758,11 +763,11 @@ class Client(Methods, BaseClient):
return signed_in
while True:
first_name = input("Enter first name: ")
last_name = input("Enter last name (empty to skip): ")
first_name = await ainput("Enter first name: ")
last_name = await ainput("Enter last name (empty to skip): ")
try:
signed_up = self.sign_up(
signed_up = await self.sign_up(
self.phone_number,
sent_code.phone_code_hash,
first_name,
@ -775,11 +780,11 @@ class Client(Methods, BaseClient):
if isinstance(signed_in, TermsOfService):
print("\n" + signed_in.text + "\n")
self.accept_terms_of_service(signed_in.id)
await self.accept_terms_of_service(signed_in.id)
return signed_up
def log_out(self):
async def log_out(self):
"""Log out from Telegram and delete the *\\*.session* file.
When you log out, the current client is stopped and the storage session deleted.
@ -794,13 +799,13 @@ class Client(Methods, BaseClient):
# Log out.
app.log_out()
"""
self.send(functions.auth.LogOut())
self.stop()
await self.send(functions.auth.LogOut())
await self.stop()
self.storage.delete()
return True
def start(self):
async def start(self):
"""Start the client.
This method connects the client to Telegram and, in case of new sessions, automatically manages the full
@ -825,25 +830,25 @@ class Client(Methods, BaseClient):
app.stop()
"""
is_authorized = self.connect()
is_authorized = await self.connect()
try:
if not is_authorized:
self.authorize()
await self.authorize()
if not self.storage.is_bot() and self.takeout:
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
self.takeout_id = (await self.send(functions.account.InitTakeoutSession())).id
log.warning("Takeout session {} initiated".format(self.takeout_id))
self.send(functions.updates.GetState())
await self.send(functions.updates.GetState())
except (Exception, KeyboardInterrupt):
self.disconnect()
await self.disconnect()
raise
else:
self.initialize()
await self.initialize()
return self
def stop(self, block: bool = True):
async def stop(self, block: bool = True):
"""Stop the Client.
This method disconnects the client from Telegram and stops the underlying tasks.
@ -874,18 +879,18 @@ class Client(Methods, BaseClient):
app.stop()
"""
def do_it():
self.terminate()
self.disconnect()
async def do_it():
await self.terminate()
await self.disconnect()
if block:
do_it()
await do_it()
else:
Thread(target=do_it).start()
asyncio.ensure_future(do_it())
return self
def restart(self, block: bool = True):
async def restart(self, block: bool = True):
"""Restart the Client.
This method will first call :meth:`~Client.stop` and then :meth:`~Client.start` in a row in order to restart
@ -921,19 +926,19 @@ class Client(Methods, BaseClient):
app.stop()
"""
def do_it():
self.stop()
self.start()
async def do_it():
await self.stop()
await self.start()
if block:
do_it()
await do_it()
else:
Thread(target=do_it).start()
asyncio.ensure_future(do_it())
return self
@staticmethod
def idle(stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)):
async def idle(stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)):
"""Block the main script execution until a signal is received.
This static method will run an infinite loop in order to block the main script execution and prevent it from
@ -978,6 +983,7 @@ class Client(Methods, BaseClient):
"""
def signal_handler(_, __):
logging.info("Stop signal received ({}). Exiting...".format(_))
Client.is_idling = False
for s in stop_signals:
@ -986,9 +992,9 @@ class Client(Methods, BaseClient):
Client.is_idling = True
while Client.is_idling:
time.sleep(1)
await asyncio.sleep(1)
def run(self):
def run(self, coroutine=None):
"""Start the client, idle the main script and finally stop the client.
This is a convenience method that calls :meth:`~Client.start`, :meth:`~Client.idle` and :meth:`~Client.stop` in
@ -1010,9 +1016,17 @@ class Client(Methods, BaseClient):
app.run()
"""
self.start()
Client.idle()
self.stop()
loop = asyncio.get_event_loop()
run = loop.run_until_complete
if coroutine is not None:
run(coroutine)
else:
run(self.start())
run(Client.idle())
run(self.stop())
loop.close()
def add_handler(self, handler: Handler, group: int = 0):
"""Register an update handler.
@ -1236,12 +1250,9 @@ class Client(Methods, BaseClient):
return is_min
def download_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
async def download_worker(self):
while True:
packet = self.download_queue.get()
packet = await self.download_queue.get()
if packet is None:
break
@ -1252,7 +1263,7 @@ class Client(Methods, BaseClient):
try:
data, directory, file_name, done, progress, progress_args, path = packet
temp_file_path = self.get_file(
temp_file_path = await self.get_file(
media_type=data.media_type,
dc_id=data.dc_id,
document_id=data.document_id,
@ -1289,14 +1300,9 @@ class Client(Methods, BaseClient):
finally:
done.set()
log.debug("{} stopped".format(name))
def updates_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
async def updates_worker(self):
while True:
updates = self.updates_queue.get()
updates = await self.updates_queue.get()
if updates is None:
break
@ -1328,9 +1334,9 @@ class Client(Methods, BaseClient):
if not isinstance(message, types.MessageEmpty):
try:
diff = self.send(
diff = await self.send(
functions.updates.GetChannelDifference(
channel=self.resolve_peer(utils.get_channel_id(channel_id)),
channel=await self.resolve_peer(utils.get_channel_id(channel_id)),
filter=types.ChannelMessagesFilter(
ranges=[types.MessageRange(
min_id=update.message.id,
@ -1348,9 +1354,9 @@ class Client(Methods, BaseClient):
users.update({u.id: u for u in diff.users})
chats.update({c.id: c for c in diff.chats})
self.dispatcher.updates_queue.put((update, users, chats))
self.dispatcher.updates_queue.put_nowait((update, users, chats))
elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
diff = self.send(
diff = await self.send(
functions.updates.GetDifference(
pts=updates.pts - updates.pts_count,
date=updates.date,
@ -1359,7 +1365,7 @@ class Client(Methods, BaseClient):
)
if diff.new_messages:
self.dispatcher.updates_queue.put((
self.dispatcher.updates_queue.put_nowait((
types.UpdateNewMessage(
message=diff.new_messages[0],
pts=updates.pts,
@ -1369,17 +1375,15 @@ class Client(Methods, BaseClient):
{c.id: c for c in diff.chats}
))
else:
self.dispatcher.updates_queue.put((diff.other_updates[0], {}, {}))
self.dispatcher.updates_queue.put_nowait((diff.other_updates[0], {}, {}))
elif isinstance(updates, types.UpdateShort):
self.dispatcher.updates_queue.put((updates.update, {}, {}))
self.dispatcher.updates_queue.put_nowait((updates.update, {}, {}))
elif isinstance(updates, types.UpdatesTooLong):
log.info(updates)
except Exception as e:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))
def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
async def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
"""Send raw Telegram queries.
This method makes it possible to manually call every single Telegram API method in a low-level manner.
@ -1417,7 +1421,7 @@ class Client(Methods, BaseClient):
if self.takeout_id:
data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data)
r = self.session.send(data, retries, timeout, self.sleep_threshold)
r = await self.session.send(data, retries, timeout, self.sleep_threshold)
self.fetch_peers(getattr(r, "users", []))
self.fetch_peers(getattr(r, "chats", []))
@ -1497,7 +1501,7 @@ class Client(Methods, BaseClient):
except KeyError:
self.plugins = None
def load_session(self):
async def load_session(self):
self.storage.open()
session_empty = any([
@ -1512,7 +1516,7 @@ class Client(Methods, BaseClient):
self.storage.date(0)
self.storage.test_mode(self.test_mode)
self.storage.auth_key(Auth(self, self.storage.dc_id()).create())
self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
self.storage.user_id(None)
self.storage.is_bot(None)
@ -1632,13 +1636,13 @@ class Client(Methods, BaseClient):
self.session_name, name, module_path))
if count > 0:
log.warning('[{}] Successfully loaded {} plugin{} from "{}"'.format(
log.info('[{}] Successfully loaded {} plugin{} from "{}"'.format(
self.session_name, count, "s" if count > 1 else "", root))
else:
log.warning('[{}] No plugin loaded from "{}"'.format(
self.session_name, root))
def resolve_peer(self, peer_id: Union[int, str]):
async def resolve_peer(self, peer_id: Union[int, str]):
"""Get the InputPeer of a known peer id.
Useful whenever an InputPeer type is required.
@ -1677,7 +1681,7 @@ class Client(Methods, BaseClient):
try:
return self.storage.get_peer_by_username(peer_id)
except KeyError:
self.send(
await self.send(
functions.contacts.ResolveUsername(
username=peer_id
)
@ -1694,7 +1698,7 @@ class Client(Methods, BaseClient):
if peer_type == "user":
self.fetch_peers(
self.send(
await self.send(
functions.users.GetUsers(
id=[
types.InputUser(
@ -1706,13 +1710,13 @@ class Client(Methods, BaseClient):
)
)
elif peer_type == "chat":
self.send(
await self.send(
functions.messages.GetChats(
id=[-peer_id]
)
)
else:
self.send(
await self.send(
functions.channels.GetChannels(
id=[
types.InputChannel(
@ -1728,7 +1732,7 @@ class Client(Methods, BaseClient):
except KeyError:
raise PeerIdInvalid
def save_file(
async def save_file(
self,
path: Union[str, BinaryIO],
file_id: int = None,
@ -1786,6 +1790,18 @@ class Client(Methods, BaseClient):
if path is None:
return None
async def worker(session):
while True:
data = await queue.get()
if data is None:
return
try:
await asyncio.ensure_future(session.send(data))
except Exception as e:
logging.error(e)
part_size = 512 * 1024
if isinstance(path, str):
@ -1808,15 +1824,20 @@ class Client(Methods, BaseClient):
raise ValueError("Telegram doesn't support uploading files bigger than 2000 MiB")
file_total_parts = int(math.ceil(file_size / part_size))
is_big = True if file_size > 10 * 1024 * 1024 else False
is_missing_part = True if file_id is not None else False
is_big = file_size > 10 * 1024 * 1024
pool_size = 3 if is_big else 1
workers_count = 4 if is_big else 1
is_missing_part = file_id is not None
file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(self, self.storage.dc_id(), self.storage.auth_key(), is_media=True)
session.start()
pool = [Session(self, self.storage.dc_id(), self.storage.auth_key(), is_media=True) for _ in range(pool_size)]
workers = [asyncio.ensure_future(worker(session)) for session in pool for _ in range(workers_count)]
queue = asyncio.Queue(16)
try:
for session in pool:
await session.start()
with fp:
fp.seek(part_size * file_part)
@ -1828,7 +1849,6 @@ class Client(Methods, BaseClient):
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break
for _ in range(3):
if is_big:
rpc = functions.upload.SaveBigFilePart(
file_id=file_id,
@ -1843,10 +1863,7 @@ class Client(Methods, BaseClient):
bytes=chunk
)
if session.send(rpc):
break
else:
raise AssertionError("Telegram didn't accept chunk #{} of {}".format(file_part, path))
await queue.put(rpc)
if is_missing_part:
return
@ -1857,7 +1874,7 @@ class Client(Methods, BaseClient):
file_part += 1
if progress:
progress(min(file_part * part_size, file_size), file_size, *progress_args)
await progress(min(file_part * part_size, file_size), file_size, *progress_args)
except Client.StopTransmission:
raise
except Exception as e:
@ -1878,9 +1895,15 @@ class Client(Methods, BaseClient):
md5_checksum=md5_sum
)
finally:
session.stop()
for _ in workers:
await queue.put(None)
def get_file(
await asyncio.gather(*workers)
for session in pool:
await session.stop()
async def get_file(
self,
media_type: int,
dc_id: int,
@ -1898,23 +1921,23 @@ class Client(Methods, BaseClient):
progress: callable,
progress_args: tuple = ()
) -> str:
with self.media_sessions_lock:
async with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
if session is None:
if dc_id != self.storage.dc_id():
session = Session(self, dc_id, Auth(self, dc_id).create(), is_media=True)
session.start()
session = Session(self, dc_id, await Auth(self, dc_id).create(), is_media=True)
await session.start()
for _ in range(3):
exported_auth = self.send(
exported_auth = await self.send(
functions.auth.ExportAuthorization(
dc_id=dc_id
)
)
try:
session.send(
await session.send(
functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
@ -1925,11 +1948,11 @@ class Client(Methods, BaseClient):
else:
break
else:
session.stop()
await session.stop()
raise AuthBytesInvalid
else:
session = Session(self, dc_id, self.storage.auth_key(), is_media=True)
session.start()
await session.start()
self.media_sessions[dc_id] = session
@ -1984,7 +2007,7 @@ class Client(Methods, BaseClient):
file_name = ""
try:
r = session.send(
r = await session.send(
functions.upload.GetFile(
location=location,
offset=offset,
@ -2007,7 +2030,7 @@ class Client(Methods, BaseClient):
offset += limit
if progress:
progress(
await progress(
min(offset, file_size)
if file_size != 0
else offset,
@ -2015,7 +2038,7 @@ class Client(Methods, BaseClient):
*progress_args
)
r = session.send(
r = await session.send(
functions.upload.GetFile(
location=location,
offset=offset,
@ -2024,13 +2047,16 @@ class Client(Methods, BaseClient):
)
elif isinstance(r, types.upload.FileCdnRedirect):
with self.media_sessions_lock:
async with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None:
cdn_session = Session(self, r.dc_id, Auth(self, r.dc_id).create(), is_media=True, is_cdn=True)
cdn_session = Session(
self,
r.dc_id,
await Auth(self, r.dc_id).create(), is_media=True, is_cdn=True)
cdn_session.start()
await cdn_session.start()
self.media_sessions[r.dc_id] = cdn_session
@ -2039,7 +2065,7 @@ class Client(Methods, BaseClient):
file_name = f.name
while True:
r2 = cdn_session.send(
r2 = await cdn_session.send(
functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
@ -2049,7 +2075,7 @@ class Client(Methods, BaseClient):
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
await session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
@ -2072,7 +2098,7 @@ class Client(Methods, BaseClient):
)
)
hashes = session.send(
hashes = await session.send(
functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset
@ -2089,7 +2115,7 @@ class Client(Methods, BaseClient):
offset += limit
if progress:
progress(
await progress(
min(offset, file_size)
if file_size != 0
else offset,

View File

@ -16,13 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import os
import platform
import re
import sys
from pathlib import Path
from queue import Queue
from threading import Lock
from pyrogram import __version__
from ..parser import Parser
@ -30,7 +29,7 @@ from ...session.internals import MsgId
class BaseClient:
class StopTransmission(StopIteration):
class StopTransmission(StopAsyncIteration):
pass
APP_VERSION = "Pyrogram {}".format(__version__)
@ -52,7 +51,7 @@ class BaseClient:
INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$")
DIALOGS_AT_ONCE = 100
UPDATES_WORKERS = 4
DOWNLOAD_WORKERS = 1
DOWNLOAD_WORKERS = 4
OFFLINE_SLEEP = 900
WORKERS = 4
WORKDIR = PARENT_DIR
@ -100,24 +99,24 @@ class BaseClient:
self.session = None
self.media_sessions = {}
self.media_sessions_lock = Lock()
self.media_sessions_lock = asyncio.Lock()
self.is_connected = None
self.is_initialized = None
self.takeout_id = None
self.updates_queue = Queue()
self.updates_workers_list = []
self.download_queue = Queue()
self.download_workers_list = []
self.updates_queue = asyncio.Queue()
self.updates_worker_tasks = []
self.download_queue = asyncio.Queue()
self.download_worker_tasks = []
self.disconnect_handler = None
def send(self, *args, **kwargs):
async def send(self, *args, **kwargs):
pass
def resolve_peer(self, *args, **kwargs):
async def resolve_peer(self, *args, **kwargs):
pass
def fetch_peers(self, *args, **kwargs):
@ -126,25 +125,46 @@ class BaseClient:
def add_handler(self, *args, **kwargs):
pass
def save_file(self, *args, **kwargs):
async def save_file(self, *args, **kwargs):
pass
def get_messages(self, *args, **kwargs):
async def get_messages(self, *args, **kwargs):
pass
def get_history(self, *args, **kwargs):
async def get_history(self, *args, **kwargs):
pass
def get_dialogs(self, *args, **kwargs):
async def get_dialogs(self, *args, **kwargs):
pass
def get_chat_members(self, *args, **kwargs):
async def get_chat_members(self, *args, **kwargs):
pass
def get_chat_members_count(self, *args, **kwargs):
async def get_chat_members_count(self, *args, **kwargs):
pass
def answer_inline_query(self, *args, **kwargs):
async def answer_inline_query(self, *args, **kwargs):
pass
async def get_profile_photos(self, *args, **kwargs):
pass
async def edit_message_text(self, *args, **kwargs):
pass
async def edit_inline_text(self, *args, **kwargs):
pass
async def edit_message_media(self, *args, **kwargs):
pass
async def edit_inline_media(self, *args, **kwargs):
pass
async def edit_message_reply_markup(self, *args, **kwargs):
pass
async def edit_inline_reply_markup(self, *args, **kwargs):
pass
def guess_mime_type(self, *args, **kwargs):
@ -152,24 +172,3 @@ class BaseClient:
def guess_extension(self, *args, **kwargs):
pass
def get_profile_photos(self, *args, **kwargs):
pass
def edit_message_text(self, *args, **kwargs):
pass
def edit_inline_text(self, *args, **kwargs):
pass
def edit_message_media(self, *args, **kwargs):
pass
def edit_inline_media(self, *args, **kwargs):
pass
def edit_message_reply_markup(self, *args, **kwargs):
pass
def edit_inline_reply_markup(self, *args, **kwargs):
pass

View File

@ -16,11 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
import threading
from collections import OrderedDict
from queue import Queue
from threading import Thread, Lock
import pyrogram
from pyrogram.api.types import (
@ -69,75 +67,74 @@ class Dispatcher:
self.client = client
self.workers = workers
self.workers_list = []
self.update_worker_tasks = []
self.locks_list = []
self.updates_queue = Queue()
self.updates_queue = asyncio.Queue()
self.groups = OrderedDict()
async def message_parser(update, users, chats):
return await pyrogram.Message._parse(
self.client, update.message, users, chats,
isinstance(update, UpdateNewScheduledMessage)
), MessageHandler
async def deleted_messages_parser(update, users, chats):
return utils.parse_deleted_messages(self.client, update), DeletedMessagesHandler
async def callback_query_parser(update, users, chats):
return await pyrogram.CallbackQuery._parse(self.client, update, users), CallbackQueryHandler
async def user_status_parser(update, users, chats):
return pyrogram.User._parse_user_status(self.client, update), UserStatusHandler
async def inline_query_parser(update, users, chats):
return pyrogram.InlineQuery._parse(self.client, update, users), InlineQueryHandler
async def poll_parser(update, users, chats):
return pyrogram.Poll._parse_update(self.client, update), PollHandler
async def chosen_inline_result_parser(update, users, chats):
return pyrogram.ChosenInlineResult._parse(self.client, update, users), ChosenInlineResultHandler
self.update_parsers = {
Dispatcher.MESSAGE_UPDATES:
lambda upd, usr, cht: (
pyrogram.Message._parse(
self.client,
upd.message,
usr,
cht,
isinstance(upd, UpdateNewScheduledMessage)
),
MessageHandler
),
Dispatcher.DELETE_MESSAGES_UPDATES:
lambda upd, usr, cht: (utils.parse_deleted_messages(self.client, upd), DeletedMessagesHandler),
Dispatcher.CALLBACK_QUERY_UPDATES:
lambda upd, usr, cht: (pyrogram.CallbackQuery._parse(self.client, upd, usr), CallbackQueryHandler),
(UpdateUserStatus,):
lambda upd, usr, cht: (pyrogram.User._parse_user_status(self.client, upd), UserStatusHandler),
(UpdateBotInlineQuery,):
lambda upd, usr, cht: (pyrogram.InlineQuery._parse(self.client, upd, usr), InlineQueryHandler),
(UpdateMessagePoll,):
lambda upd, usr, cht: (pyrogram.Poll._parse_update(self.client, upd), PollHandler),
(UpdateBotInlineSend,):
lambda upd, usr, cht: (pyrogram.ChosenInlineResult._parse(self.client, upd, usr),
ChosenInlineResultHandler)
Dispatcher.MESSAGE_UPDATES: message_parser,
Dispatcher.DELETE_MESSAGES_UPDATES: deleted_messages_parser,
Dispatcher.CALLBACK_QUERY_UPDATES: callback_query_parser,
(UpdateUserStatus,): user_status_parser,
(UpdateBotInlineQuery,): inline_query_parser,
(UpdateMessagePoll,): poll_parser,
(UpdateBotInlineSend,): chosen_inline_result_parser
}
self.update_parsers = {key: value for key_tuple, value in self.update_parsers.items() for key in key_tuple}
def start(self):
async def start(self):
for i in range(self.workers):
self.locks_list.append(Lock())
self.locks_list.append(asyncio.Lock())
self.workers_list.append(
Thread(
target=self.update_worker,
name="UpdateWorker#{}".format(i + 1),
args=(self.locks_list[-1],)
)
self.update_worker_tasks.append(
asyncio.ensure_future(self.update_worker(self.locks_list[-1]))
)
self.workers_list[-1].start()
logging.info("Started {} UpdateWorkerTasks".format(self.workers))
def stop(self):
for _ in range(self.workers):
self.updates_queue.put(None)
async def stop(self):
for i in range(self.workers):
self.updates_queue.put_nowait(None)
for worker in self.workers_list:
worker.join()
for i in self.update_worker_tasks:
await i
self.workers_list.clear()
self.locks_list.clear()
self.update_worker_tasks.clear()
self.groups.clear()
logging.info("Stopped {} UpdateWorkerTasks".format(self.workers))
def add_handler(self, handler, group: int):
async def fn():
for lock in self.locks_list:
lock.acquire()
await lock.acquire()
try:
if group not in self.groups:
@ -149,9 +146,12 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
asyncio.ensure_future(fn())
def remove_handler(self, handler, group: int):
async def fn():
for lock in self.locks_list:
lock.acquire()
await lock.acquire()
try:
if group not in self.groups:
@ -162,12 +162,11 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
def update_worker(self, lock):
name = threading.current_thread().name
log.debug("{} started".format(name))
asyncio.ensure_future(fn())
async def update_worker(self, lock):
while True:
packet = self.updates_queue.get()
packet = await self.updates_queue.get()
if packet is None:
break
@ -177,12 +176,12 @@ class Dispatcher:
parser = self.update_parsers.get(type(update), None)
parsed_update, handler_type = (
parser(update, users, chats)
await parser(update, users, chats)
if parser is not None
else (None, type(None))
)
with lock:
async with lock:
for group in self.groups.values():
for handler in group:
args = None
@ -202,7 +201,7 @@ class Dispatcher:
continue
try:
handler.callback(self.client, *args)
await handler.callback(self.client, *args)
except pyrogram.StopPropagation:
raise
except pyrogram.ContinuePropagation:
@ -215,5 +214,3 @@ class Dispatcher:
pass
except Exception as e:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))

View File

@ -16,9 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
import time
from threading import Thread, Event, Lock
log = logging.getLogger(__name__)
@ -27,13 +27,18 @@ class Syncer:
INTERVAL = 20
clients = {}
thread = None
event = Event()
lock = Lock()
event = None
lock = None
@classmethod
def add(cls, client):
with cls.lock:
async def add(cls, client):
if cls.event is None:
cls.event = asyncio.Event()
if cls.lock is None:
cls.lock = asyncio.Lock()
async with cls.lock:
cls.sync(client)
cls.clients[id(client)] = client
@ -42,8 +47,8 @@ class Syncer:
cls.start()
@classmethod
def remove(cls, client):
with cls.lock:
async def remove(cls, client):
async with cls.lock:
cls.sync(client)
del cls.clients[id(client)]
@ -54,24 +59,23 @@ class Syncer:
@classmethod
def start(cls):
cls.event.clear()
cls.thread = Thread(target=cls.worker, name=cls.__name__)
cls.thread.start()
asyncio.ensure_future(cls.worker())
@classmethod
def stop(cls):
cls.event.set()
@classmethod
def worker(cls):
async def worker(cls):
while True:
cls.event.wait(cls.INTERVAL)
if cls.event.is_set():
break
with cls.lock:
try:
await asyncio.wait_for(cls.event.wait(), cls.INTERVAL)
except asyncio.TimeoutError:
async with cls.lock:
for client in cls.clients.values():
cls.sync(client)
else:
break
@classmethod
def sync(cls, client):

View File

@ -16,8 +16,11 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import base64
import struct
import sys
from concurrent.futures.thread import ThreadPoolExecutor
from typing import List
from typing import Union
@ -80,6 +83,15 @@ def decode_file_ref(file_ref: str) -> bytes:
return base64.urlsafe_b64decode(file_ref + "=" * (-len(file_ref) % 4))
async def ainput(prompt: str = ""):
print(prompt, end="", flush=True)
with ThreadPoolExecutor(1) as executor:
return (await asyncio.get_event_loop().run_in_executor(
executor, sys.stdin.readline
)).rstrip()
def get_offset_date(dialogs):
for m in reversed(dialogs.messages):
if isinstance(m, types.MessageEmpty):
@ -141,24 +153,24 @@ def get_input_media_from_file_id(
raise ValueError("Unknown media type: {}".format(file_id_str))
def parse_messages(client, messages: types.messages.Messages, replies: int = 1) -> List["pyrogram.Message"]:
async def parse_messages(client, messages: types.messages.Messages, replies: int = 1) -> List["pyrogram.Message"]:
users = {i.id: i for i in messages.users}
chats = {i.id: i for i in messages.chats}
if not messages.messages:
return pyrogram.List()
parsed_messages = [
pyrogram.Message._parse(client, message, users, chats, replies=0)
for message in messages.messages
]
parsed_messages = []
for message in messages.messages:
parsed_messages.append(await pyrogram.Message._parse(client, message, users, chats, replies=0))
if replies:
messages_with_replies = {i.id: getattr(i, "reply_to_msg_id", None) for i in messages.messages}
reply_message_ids = [i[0] for i in filter(lambda x: x[1] is not None, messages_with_replies.items())]
if reply_message_ids:
reply_messages = client.get_messages(
reply_messages = await client.get_messages(
parsed_messages[0].chat.id,
reply_to_message_ids=reply_message_ids,
replies=replies - 1

View File

@ -21,7 +21,7 @@ from pyrogram.client.ext import BaseClient
class AnswerCallbackQuery(BaseClient):
def answer_callback_query(
async def answer_callback_query(
self,
callback_query_id: str,
text: str = None,
@ -68,7 +68,7 @@ class AnswerCallbackQuery(BaseClient):
# Answer with alert
app.answer_callback_query(query_id, text=text, show_alert=True)
"""
return self.send(
return await self.send(
functions.messages.SetBotCallbackAnswer(
query_id=int(callback_query_id),
cache_time=cache_time,

View File

@ -24,7 +24,7 @@ from ...types.inline_mode import InlineQueryResult
class AnswerInlineQuery(BaseClient):
def answer_inline_query(
async def answer_inline_query(
self,
inline_query_id: str,
results: List[InlineQueryResult],
@ -93,10 +93,15 @@ class AnswerInlineQuery(BaseClient):
"Title",
InputTextMessageContent("Message content"))])
"""
return self.send(
written_results = [] # Py 3.5 doesn't support await inside comprehensions
for r in results:
written_results.append(await r.write())
return await self.send(
functions.messages.SetInlineBotResults(
query_id=int(inline_query_id),
results=[r.write() for r in results],
results=written_results,
cache_time=cache_time,
gallery=is_gallery or None,
private=is_personal or None,

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class GetGameHighScores(BaseClient):
def get_game_high_scores(
async def get_game_high_scores(
self,
user_id: Union[int, str],
chat_id: Union[int, str],
@ -59,11 +59,11 @@ class GetGameHighScores(BaseClient):
"""
# TODO: inline_message_id
r = self.send(
r = await self.send(
functions.messages.GetGameHighScores(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
id=message_id,
user_id=self.resolve_peer(user_id)
user_id=await self.resolve_peer(user_id)
)
)

View File

@ -24,7 +24,7 @@ from pyrogram.errors import UnknownError
class GetInlineBotResults(BaseClient):
def get_inline_bot_results(
async def get_inline_bot_results(
self,
bot: Union[int, str],
query: str = "",
@ -70,9 +70,9 @@ class GetInlineBotResults(BaseClient):
# TODO: Don't return the raw type
try:
return self.send(
return await self.send(
functions.messages.GetInlineBotResults(
bot=self.resolve_peer(bot),
bot=await self.resolve_peer(bot),
peer=types.InputPeerSelf(),
query=query,
offset=offset,

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class RequestCallbackAnswer(BaseClient):
def request_callback_answer(
async def request_callback_answer(
self,
chat_id: Union[int, str],
message_id: int,
@ -64,9 +64,9 @@ class RequestCallbackAnswer(BaseClient):
# Telegram only wants bytes, but we are allowed to pass strings too.
data = bytes(callback_data, "utf-8") if isinstance(callback_data, str) else callback_data
return self.send(
return await self.send(
functions.messages.GetBotCallbackAnswer(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
msg_id=message_id,
data=data
),

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendGame(BaseClient):
def send_game(
async def send_game(
self,
chat_id: Union[int, str],
game_short_name: str,
@ -67,9 +67,9 @@ class SendGame(BaseClient):
app.send_game(chat_id, "gamename")
"""
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaGame(
id=types.InputGameShortName(
bot_id=types.InputUserSelf(),
@ -86,7 +86,7 @@ class SendGame(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats}

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class SendInlineBotResult(BaseClient):
def send_inline_bot_result(
async def send_inline_bot_result(
self,
chat_id: Union[int, str],
query_id: int,
@ -65,9 +65,9 @@ class SendInlineBotResult(BaseClient):
app.send_inline_bot_result(chat_id, query_id, result_id)
"""
return self.send(
return await self.send(
functions.messages.SendInlineBotResult(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
query_id=query_id,
id=result_id,
random_id=self.rnd_id(),

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SetGameScore(BaseClient):
def set_game_score(
async def set_game_score(
self,
user_id: Union[int, str],
score: int,
@ -75,12 +75,12 @@ class SetGameScore(BaseClient):
# Force set new score
app.set_game_score(user_id, 25, force=True)
"""
r = self.send(
r = await self.send(
functions.messages.SetGameScore(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
score=score,
id=message_id,
user_id=self.resolve_peer(user_id),
user_id=await self.resolve_peer(user_id),
force=force or None,
edit_message=not disable_edit_message or None
)
@ -88,7 +88,7 @@ class SetGameScore(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats}

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class AddChatMembers(BaseClient):
def add_chat_members(
async def add_chat_members(
self,
chat_id: Union[int, str],
user_ids: Union[Union[int, str], List[Union[int, str]]],
@ -60,26 +60,26 @@ class AddChatMembers(BaseClient):
# Change forward_limit (for basic groups only)
app.add_chat_members(chat_id, user_id, forward_limit=25)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if not isinstance(user_ids, list):
user_ids = [user_ids]
if isinstance(peer, types.InputPeerChat):
for user_id in user_ids:
self.send(
await self.send(
functions.messages.AddChatUser(
chat_id=peer.chat_id,
user_id=self.resolve_peer(user_id),
user_id=await self.resolve_peer(user_id),
fwd_limit=forward_limit
)
)
else:
self.send(
await self.send(
functions.channels.InviteToChannel(
channel=peer,
users=[
self.resolve_peer(user_id)
await self.resolve_peer(user_id)
for user_id in user_ids
]
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class ArchiveChats(BaseClient):
def archive_chats(
async def archive_chats(
self,
chat_ids: Union[int, str, List[Union[int, str]]],
) -> bool:
@ -50,14 +50,19 @@ class ArchiveChats(BaseClient):
if not isinstance(chat_ids, list):
chat_ids = [chat_ids]
self.send(
functions.folders.EditPeerFolders(
folder_peers=[
folder_peers = []
for chat in chat_ids:
folder_peers.append(
types.InputFolderPeer(
peer=self.resolve_peer(chat),
peer=await self.resolve_peer(chat),
folder_id=1
) for chat in chat_ids
]
)
)
await self.send(
functions.folders.EditPeerFolders(
folder_peers=folder_peers
)
)

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class CreateChannel(BaseClient):
def create_channel(
async def create_channel(
self,
title: str,
description: str = ""
@ -44,7 +44,7 @@ class CreateChannel(BaseClient):
app.create_channel("Channel Title", "Channel Description")
"""
r = self.send(
r = await self.send(
functions.channels.CreateChannel(
title=title,
about=description,

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class CreateGroup(BaseClient):
def create_group(
async def create_group(
self,
title: str,
users: Union[Union[int, str], List[Union[int, str]]]
@ -55,10 +55,10 @@ class CreateGroup(BaseClient):
if not isinstance(users, list):
users = [users]
r = self.send(
r = await self.send(
functions.messages.CreateChat(
title=title,
users=[self.resolve_peer(u) for u in users]
users=[await self.resolve_peer(u) for u in users]
)
)

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class CreateSupergroup(BaseClient):
def create_supergroup(
async def create_supergroup(
self,
title: str,
description: str = ""
@ -48,7 +48,7 @@ class CreateSupergroup(BaseClient):
app.create_supergroup("Supergroup Title", "Supergroup Description")
"""
r = self.send(
r = await self.send(
functions.channels.CreateChannel(
title=title,
about=description,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class DeleteChannel(BaseClient):
def delete_channel(self, chat_id: Union[int, str]) -> bool:
async def delete_channel(self, chat_id: Union[int, str]) -> bool:
"""Delete a channel.
Parameters:
@ -38,9 +38,9 @@ class DeleteChannel(BaseClient):
app.delete_channel(channel_id)
"""
self.send(
await self.send(
functions.channels.DeleteChannel(
channel=self.resolve_peer(chat_id)
channel=await self.resolve_peer(chat_id)
)
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class DeleteChatPhoto(BaseClient):
def delete_chat_photo(
async def delete_chat_photo(
self,
chat_id: Union[int, str]
) -> bool:
@ -46,17 +46,17 @@ class DeleteChatPhoto(BaseClient):
app.delete_chat_photo(chat_id)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat):
self.send(
await self.send(
functions.messages.EditChatPhoto(
chat_id=peer.chat_id,
photo=types.InputChatPhotoEmpty()
)
)
elif isinstance(peer, types.InputPeerChannel):
self.send(
await self.send(
functions.channels.EditPhoto(
channel=peer,
photo=types.InputChatPhotoEmpty()

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class DeleteSupergroup(BaseClient):
def delete_supergroup(self, chat_id: Union[int, str]) -> bool:
async def delete_supergroup(self, chat_id: Union[int, str]) -> bool:
"""Delete a supergroup.
Parameters:
@ -38,9 +38,9 @@ class DeleteSupergroup(BaseClient):
app.delete_supergroup(supergroup_id)
"""
self.send(
await self.send(
functions.channels.DeleteChannel(
channel=self.resolve_peer(chat_id)
channel=await self.resolve_peer(chat_id)
)
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class ExportChatInviteLink(BaseClient):
def export_chat_invite_link(
async def export_chat_invite_link(
self,
chat_id: Union[int, str]
) -> str:
@ -55,13 +55,13 @@ class ExportChatInviteLink(BaseClient):
link = app.export_chat_invite_link(chat_id)
print(link)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, (types.InputPeerChat, types.InputPeerChannel)):
return self.send(
return (await self.send(
functions.messages.ExportChatInvite(
peer=peer
)
).link
)).link
else:
raise ValueError('The chat_id "{}" belongs to a user'.format(chat_id))

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient, utils
class GetChat(BaseClient):
def get_chat(
async def get_chat(
self,
chat_id: Union[int, str]
) -> Union["pyrogram.Chat", "pyrogram.ChatPreview"]:
@ -55,7 +55,7 @@ class GetChat(BaseClient):
match = self.INVITE_LINK_RE.match(str(chat_id))
if match:
r = self.send(
r = await self.send(
functions.messages.CheckChatInvite(
hash=match.group(1)
)
@ -72,13 +72,13 @@ class GetChat(BaseClient):
if isinstance(r.chat, types.Channel):
chat_id = utils.get_channel_id(r.chat.id)
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel):
r = self.send(functions.channels.GetFullChannel(channel=peer))
r = await self.send(functions.channels.GetFullChannel(channel=peer))
elif isinstance(peer, (types.InputPeerUser, types.InputPeerSelf)):
r = self.send(functions.users.GetFullUser(id=peer))
r = await self.send(functions.users.GetFullUser(id=peer))
else:
r = self.send(functions.messages.GetFullChat(chat_id=peer.chat_id))
r = await self.send(functions.messages.GetFullChat(chat_id=peer.chat_id))
return pyrogram.Chat._parse_full(self, r)
return await pyrogram.Chat._parse_full(self, r)

View File

@ -21,11 +21,12 @@ from typing import Union
import pyrogram
from pyrogram.api import functions, types
from pyrogram.errors import UserNotParticipant
from ...ext import BaseClient
class GetChatMember(BaseClient):
def get_chat_member(
async def get_chat_member(
self,
chat_id: Union[int, str],
user_id: Union[int, str]
@ -50,11 +51,11 @@ class GetChatMember(BaseClient):
dan = app.get_chat_member("pyrogramchat", "haskell")
print(dan)
"""
chat = self.resolve_peer(chat_id)
user = self.resolve_peer(user_id)
chat = await self.resolve_peer(chat_id)
user = await self.resolve_peer(user_id)
if isinstance(chat, types.InputPeerChat):
r = self.send(
r = await self.send(
functions.messages.GetFullChat(
chat_id=chat.chat_id
)
@ -75,7 +76,7 @@ class GetChatMember(BaseClient):
else:
raise UserNotParticipant
elif isinstance(chat, types.InputPeerChannel):
r = self.send(
r = await self.send(
functions.channels.GetParticipant(
channel=chat,
user_id=user

View File

@ -36,7 +36,7 @@ class Filters:
class GetChatMembers(BaseClient):
def get_chat_members(
async def get_chat_members(
self,
chat_id: Union[int, str],
offset: int = 0,
@ -103,10 +103,10 @@ class GetChatMembers(BaseClient):
# Get all bots
app.get_chat_members("pyrogramchat", filter="bots")
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat):
r = self.send(
r = await self.send(
functions.messages.GetFullChat(
chat_id=peer.chat_id
)
@ -134,7 +134,7 @@ class GetChatMembers(BaseClient):
else:
raise ValueError("Invalid filter \"{}\"".format(filter))
r = self.send(
r = await self.send(
functions.channels.GetParticipants(
channel=peer,
filter=filter,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class GetChatMembersCount(BaseClient):
def get_chat_members_count(
async def get_chat_members_count(
self,
chat_id: Union[int, str]
) -> int:
@ -45,19 +45,23 @@ class GetChatMembersCount(BaseClient):
count = app.get_chat_members_count("pyrogramchat")
print(count)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat):
return self.send(
r = await self.send(
functions.messages.GetChats(
id=[peer.chat_id]
)
).chats[0].participants_count
)
return r.chats[0].participants_count
elif isinstance(peer, types.InputPeerChannel):
return self.send(
r = await self.send(
functions.channels.GetFullChannel(
channel=peer
)
).full_chat.participants_count
)
return r.full_chat.participants_count
else:
raise ValueError("The chat_id \"{}\" belongs to a user".format(chat_id))

View File

@ -27,7 +27,7 @@ log = logging.getLogger(__name__)
class GetDialogs(BaseClient):
def get_dialogs(
async def get_dialogs(
self,
offset_date: int = 0,
limit: int = 100,
@ -65,9 +65,9 @@ class GetDialogs(BaseClient):
"""
if pinned_only:
r = self.send(functions.messages.GetPinnedDialogs(folder_id=0))
r = await self.send(functions.messages.GetPinnedDialogs(folder_id=0))
else:
r = self.send(
r = await self.send(
functions.messages.GetDialogs(
offset_date=offset_date,
offset_id=0,
@ -94,7 +94,7 @@ class GetDialogs(BaseClient):
else:
chat_id = utils.get_peer_id(to_id)
messages[chat_id] = pyrogram.Message._parse(self, message, users, chats)
messages[chat_id] = await pyrogram.Message._parse(self, message, users, chats)
parsed_dialogs = []

View File

@ -21,7 +21,7 @@ from ...ext import BaseClient
class GetDialogsCount(BaseClient):
def get_dialogs_count(self, pinned_only: bool = False) -> int:
async def get_dialogs_count(self, pinned_only: bool = False) -> int:
"""Get the total count of your dialogs.
pinned_only (``bool``, *optional*):
@ -39,9 +39,9 @@ class GetDialogsCount(BaseClient):
"""
if pinned_only:
return len(self.send(functions.messages.GetPinnedDialogs(folder_id=0)).dialogs)
return len((await self.send(functions.messages.GetPinnedDialogs(folder_id=0))).dialogs)
else:
r = self.send(
r = await self.send(
functions.messages.GetDialogs(
offset_date=0,
offset_id=0,

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient, utils
class GetNearbyChats(BaseClient):
def get_nearby_chats(
async def get_nearby_chats(
self,
latitude: float,
longitude: float
@ -48,7 +48,7 @@ class GetNearbyChats(BaseClient):
print(chats)
"""
r = self.send(
r = await self.send(
functions.contacts.GetLocated(
geo_point=types.InputGeoPoint(
lat=latitude,

View File

@ -17,10 +17,12 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from string import ascii_lowercase
from typing import Union, Generator
from typing import Union, Generator, Optional
import pyrogram
from async_generator import async_generator, yield_
from pyrogram.api import types
from ...ext import BaseClient
@ -38,13 +40,14 @@ QUERYABLE_FILTERS = (Filters.ALL, Filters.KICKED, Filters.RESTRICTED)
class IterChatMembers(BaseClient):
def iter_chat_members(
@async_generator
async def iter_chat_members(
self,
chat_id: Union[int, str],
limit: int = 0,
query: str = "",
filter: str = Filters.ALL
) -> Generator["pyrogram.ChatMember", None, None]:
) -> Optional[Generator["pyrogram.ChatMember", None, None]]:
"""Iterate through the members of a chat sequentially.
This convenience method does the same as repeatedly calling :meth:`~Client.get_chat_members` in a loop, thus saving you
@ -97,7 +100,7 @@ class IterChatMembers(BaseClient):
queries = [query] if query else QUERIES
total = limit or (1 << 31) - 1
limit = min(200, total)
resolved_chat_id = self.resolve_peer(chat_id)
resolved_chat_id = await self.resolve_peer(chat_id)
if filter not in QUERYABLE_FILTERS:
queries = [""]
@ -106,7 +109,7 @@ class IterChatMembers(BaseClient):
offset = 0
while True:
chat_members = self.get_chat_members(
chat_members = await self.get_chat_members(
chat_id=chat_id,
offset=offset,
limit=limit,
@ -128,7 +131,7 @@ class IterChatMembers(BaseClient):
if user_id in yielded:
continue
yield chat_member
await yield_(chat_member)
yielded.add(chat_member.user.id)

View File

@ -16,18 +16,21 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Generator
from typing import Generator, Optional
from async_generator import async_generator, yield_
import pyrogram
from ...ext import BaseClient
class IterDialogs(BaseClient):
def iter_dialogs(
@async_generator
async def iter_dialogs(
self,
offset_date: int = 0,
limit: int = None
) -> Generator["pyrogram.Dialog", None, None]:
limit: int = 0,
offset_date: int = 0
) -> Optional[Generator["pyrogram.Dialog", None, None]]:
"""Iterate through a user's dialogs sequentially.
This convenience method does the same as repeatedly calling :meth:`~Client.get_dialogs` in a loop, thus saving
@ -35,14 +38,14 @@ class IterDialogs(BaseClient):
single call.
Parameters:
offset_date (``int``):
The offset date in Unix time taken from the top message of a :obj:`Dialog`.
Defaults to 0 (most recent dialog).
limit (``int``, *optional*):
Limits the number of dialogs to be retrieved.
By default, no limit is applied and all dialogs are returned.
offset_date (``int``):
The offset date in Unix time taken from the top message of a :obj:`Dialog`.
Defaults to 0 (most recent dialog).
Returns:
``Generator``: A generator yielding :obj:`Dialog` objects.
@ -57,12 +60,12 @@ class IterDialogs(BaseClient):
total = limit or (1 << 31) - 1
limit = min(100, total)
pinned_dialogs = self.get_dialogs(
pinned_dialogs = await self.get_dialogs(
pinned_only=True
)
for dialog in pinned_dialogs:
yield dialog
await yield_(dialog)
current += 1
@ -70,7 +73,7 @@ class IterDialogs(BaseClient):
return
while True:
dialogs = self.get_dialogs(
dialogs = await self.get_dialogs(
offset_date=offset_date,
limit=limit
)
@ -81,7 +84,7 @@ class IterDialogs(BaseClient):
offset_date = dialogs[-1].top_message.date
for dialog in dialogs:
yield dialog
await yield_(dialog)
current += 1

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class JoinChat(BaseClient):
def join_chat(
async def join_chat(
self,
chat_id: Union[int, str]
):
@ -53,7 +53,7 @@ class JoinChat(BaseClient):
match = self.INVITE_LINK_RE.match(str(chat_id))
if match:
chat = self.send(
chat = await self.send(
functions.messages.ImportChatInvite(
hash=match.group(1)
)
@ -63,9 +63,9 @@ class JoinChat(BaseClient):
elif isinstance(chat.chats[0], types.Channel):
return pyrogram.Chat._parse_channel_chat(self, chat.chats[0])
else:
chat = self.send(
chat = await self.send(
functions.channels.JoinChannel(
channel=self.resolve_peer(chat_id)
channel=await self.resolve_peer(chat_id)
)
)

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class KickChatMember(BaseClient):
def kick_chat_member(
async def kick_chat_member(
self,
chat_id: Union[int, str],
user_id: Union[int, str],
@ -68,11 +68,11 @@ class KickChatMember(BaseClient):
# Kick chat member and automatically unban after 24h
app.kick_chat_member(chat_id, user_id, int(time.time() + 86400))
"""
chat_peer = self.resolve_peer(chat_id)
user_peer = self.resolve_peer(user_id)
chat_peer = await self.resolve_peer(chat_id)
user_peer = await self.resolve_peer(user_id)
if isinstance(chat_peer, types.InputPeerChannel):
r = self.send(
r = await self.send(
functions.channels.EditBanned(
channel=chat_peer,
user_id=user_peer,
@ -90,7 +90,7 @@ class KickChatMember(BaseClient):
)
)
else:
r = self.send(
r = await self.send(
functions.messages.DeleteChatUser(
chat_id=abs(chat_id),
user_id=user_peer
@ -99,7 +99,7 @@ class KickChatMember(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats}

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class LeaveChat(BaseClient):
def leave_chat(
async def leave_chat(
self,
chat_id: Union[int, str],
delete: bool = False
@ -48,16 +48,16 @@ class LeaveChat(BaseClient):
# Leave basic chat and also delete the dialog
app.leave_chat(chat_id, delete=True)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel):
return self.send(
return await self.send(
functions.channels.LeaveChannel(
channel=self.resolve_peer(chat_id)
channel=await self.resolve_peer(chat_id)
)
)
elif isinstance(peer, types.InputPeerChat):
r = self.send(
r = await self.send(
functions.messages.DeleteChatUser(
chat_id=peer.chat_id,
user_id=types.InputPeerSelf()
@ -65,7 +65,7 @@ class LeaveChat(BaseClient):
)
if delete:
self.send(
await self.send(
functions.messages.DeleteHistory(
peer=peer,
max_id=0

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class PinChatMessage(BaseClient):
def pin_chat_message(
async def pin_chat_message(
self,
chat_id: Union[int, str],
message_id: int,
@ -56,9 +56,9 @@ class PinChatMessage(BaseClient):
# Pin without notification
app.pin_chat_message(chat_id, message_id, disable_notification=True)
"""
self.send(
await self.send(
functions.messages.UpdatePinnedMessage(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
id=message_id,
silent=disable_notification or None
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class PromoteChatMember(BaseClient):
def promote_chat_member(
async def promote_chat_member(
self,
chat_id: Union[int, str],
user_id: Union[int, str],
@ -84,10 +84,10 @@ class PromoteChatMember(BaseClient):
# Promote chat member to supergroup admin
app.promote_chat_member(chat_id, user_id)
"""
self.send(
await self.send(
functions.channels.EditAdmin(
channel=self.resolve_peer(chat_id),
user_id=self.resolve_peer(user_id),
channel=await self.resolve_peer(chat_id),
user_id=await self.resolve_peer(user_id),
admin_rights=types.ChatAdminRights(
change_info=can_change_info or None,
post_messages=can_post_messages or None,

View File

@ -24,7 +24,7 @@ from ...types.user_and_chats import Chat, ChatPermissions
class RestrictChatMember(BaseClient):
def restrict_chat_member(
async def restrict_chat_member(
self,
chat_id: Union[int, str],
user_id: Union[int, str],
@ -71,10 +71,10 @@ class RestrictChatMember(BaseClient):
# Chat member can only send text messages
app.restrict_chat_member(chat_id, user_id, ChatPermissions(can_send_messages=True))
"""
r = self.send(
r = await self.send(
functions.channels.EditBanned(
channel=self.resolve_peer(chat_id),
user_id=self.resolve_peer(user_id),
channel=await self.resolve_peer(chat_id),
user_id=await self.resolve_peer(user_id),
banned_rights=types.ChatBannedRights(
until_date=until_date,
send_messages=True if not permissions.can_send_messages else None,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetAdministratorTitle(BaseClient):
def set_administrator_title(
async def set_administrator_title(
self,
chat_id: Union[int, str],
user_id: Union[int, str],
@ -54,15 +54,15 @@ class SetAdministratorTitle(BaseClient):
app.set_administrator_title(chat_id, user_id, "ฅ^•ﻌ•^ฅ")
"""
chat_id = self.resolve_peer(chat_id)
user_id = self.resolve_peer(user_id)
chat_id = await self.resolve_peer(chat_id)
user_id = await self.resolve_peer(user_id)
r = self.send(
r = (await self.send(
functions.channels.GetParticipant(
channel=chat_id,
user_id=user_id
)
).participant
)).participant
if isinstance(r, types.ChannelParticipantCreator):
admin_rights = types.ChatAdminRights(
@ -104,7 +104,7 @@ class SetAdministratorTitle(BaseClient):
if not admin_rights.add_admins:
admin_rights.add_admins = None
self.send(
await self.send(
functions.channels.EditAdmin(
channel=chat_id,
user_id=user_id,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetChatDescription(BaseClient):
def set_chat_description(
async def set_chat_description(
self,
chat_id: Union[int, str],
description: str
@ -49,10 +49,10 @@ class SetChatDescription(BaseClient):
app.set_chat_description(chat_id, "New Description")
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, (types.InputPeerChannel, types.InputPeerChat)):
self.send(
await self.send(
functions.messages.EditChatAbout(
peer=peer,
about=description

View File

@ -24,7 +24,7 @@ from ...types.user_and_chats import Chat, ChatPermissions
class SetChatPermissions(BaseClient):
def set_chat_permissions(
async def set_chat_permissions(
self,
chat_id: Union[int, str],
permissions: ChatPermissions,
@ -63,9 +63,9 @@ class SetChatPermissions(BaseClient):
)
)
"""
r = self.send(
r = await self.send(
functions.messages.EditChatDefaultBannedRights(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
banned_rights=types.ChatBannedRights(
until_date=0,
send_messages=True if not permissions.can_send_messages else None,

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient, utils
class SetChatPhoto(BaseClient):
def set_chat_photo(
async def set_chat_photo(
self,
chat_id: Union[int, str],
*,
@ -79,32 +79,32 @@ class SetChatPhoto(BaseClient):
# Set chat photo using an exiting Video file_id
app.set_chat_photo(chat_id, video=video.file_id, file_ref=video.file_ref)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(photo, str):
if os.path.isfile(photo):
photo = types.InputChatUploadedPhoto(
file=self.save_file(photo),
video=self.save_file(video)
file=await self.save_file(photo),
video=await self.save_file(video)
)
else:
photo = utils.get_input_media_from_file_id(photo, file_ref, 2)
photo = types.InputChatPhoto(id=photo.id)
else:
photo = types.InputChatUploadedPhoto(
file=self.save_file(photo),
video=self.save_file(video)
file=await self.save_file(photo),
video=await self.save_file(video)
)
if isinstance(peer, types.InputPeerChat):
self.send(
await self.send(
functions.messages.EditChatPhoto(
chat_id=peer.chat_id,
photo=photo
)
)
elif isinstance(peer, types.InputPeerChannel):
self.send(
await self.send(
functions.channels.EditPhoto(
channel=peer,
photo=photo

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetChatTitle(BaseClient):
def set_chat_title(
async def set_chat_title(
self,
chat_id: Union[int, str],
title: str
@ -54,17 +54,17 @@ class SetChatTitle(BaseClient):
app.set_chat_title(chat_id, "New Title")
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChat):
self.send(
await self.send(
functions.messages.EditChatTitle(
chat_id=peer.chat_id,
title=title
)
)
elif isinstance(peer, types.InputPeerChannel):
self.send(
await self.send(
functions.channels.EditTitle(
channel=peer,
title=title

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class SetSlowMode(BaseClient):
def set_slow_mode(
async def set_slow_mode(
self,
chat_id: Union[int, str],
seconds: Union[int, None]
@ -51,9 +51,9 @@ class SetSlowMode(BaseClient):
app.set_slow_mode("pyrogramchat", None)
"""
self.send(
await self.send(
functions.channels.ToggleSlowMode(
channel=self.resolve_peer(chat_id),
channel=await self.resolve_peer(chat_id),
seconds=0 if seconds is None else seconds
)
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UnarchiveChats(BaseClient):
def unarchive_chats(
async def unarchive_chats(
self,
chat_ids: Union[int, str, List[Union[int, str]]],
) -> bool:
@ -50,14 +50,19 @@ class UnarchiveChats(BaseClient):
if not isinstance(chat_ids, list):
chat_ids = [chat_ids]
self.send(
functions.folders.EditPeerFolders(
folder_peers=[
folder_peers = []
for chat in chat_ids:
folder_peers.append(
types.InputFolderPeer(
peer=self.resolve_peer(chat),
peer=await self.resolve_peer(chat),
folder_id=0
) for chat in chat_ids
]
)
)
await self.send(
functions.folders.EditPeerFolders(
folder_peers=folder_peers
)
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UnbanChatMember(BaseClient):
def unban_chat_member(
async def unban_chat_member(
self,
chat_id: Union[int, str],
user_id: Union[int, str]
@ -49,10 +49,10 @@ class UnbanChatMember(BaseClient):
# Unban chat member right now
app.unban_chat_member(chat_id, user_id)
"""
self.send(
await self.send(
functions.channels.EditBanned(
channel=self.resolve_peer(chat_id),
user_id=self.resolve_peer(user_id),
channel=await self.resolve_peer(chat_id),
user_id=await self.resolve_peer(user_id),
banned_rights=types.ChatBannedRights(
until_date=0
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UnpinChatMessage(BaseClient):
def unpin_chat_message(
async def unpin_chat_message(
self,
chat_id: Union[int, str]
) -> bool:
@ -43,9 +43,9 @@ class UnpinChatMessage(BaseClient):
app.unpin_chat_message(chat_id)
"""
self.send(
await self.send(
functions.messages.UpdatePinnedMessage(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
id=0
)
)

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class UpdateChatUsername(BaseClient):
def update_chat_username(
async def update_chat_username(
self,
chat_id: Union[int, str],
username: Union[str, None]
@ -50,11 +50,11 @@ class UpdateChatUsername(BaseClient):
app.update_chat_username(chat_id, "new_username")
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel):
return bool(
self.send(
await self.send(
functions.channels.UpdateUsername(
channel=peer,
username=username or ""

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class AddContacts(BaseClient):
def add_contacts(
async def add_contacts(
self,
contacts: List["pyrogram.InputPhoneContact"]
):
@ -47,7 +47,7 @@ class AddContacts(BaseClient):
InputPhoneContact("38987654321", "Bar"),
InputPhoneContact("01234567891", "Baz")])
"""
imported_contacts = self.send(
imported_contacts = await self.send(
functions.contacts.ImportContacts(
contacts=contacts
)

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class DeleteContacts(BaseClient):
def delete_contacts(
async def delete_contacts(
self,
ids: List[int]
):
@ -47,14 +47,14 @@ class DeleteContacts(BaseClient):
for i in ids:
try:
input_user = self.resolve_peer(i)
input_user = await self.resolve_peer(i)
except PeerIdInvalid:
continue
else:
if isinstance(input_user, types.InputPeerUser):
contacts.append(input_user)
return self.send(
return await self.send(
functions.contacts.DeleteContacts(
id=contacts
)

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from typing import List
@ -27,7 +28,7 @@ log = logging.getLogger(__name__)
class GetContacts(BaseClient):
def get_contacts(self) -> List["pyrogram.User"]:
async def get_contacts(self) -> List["pyrogram.User"]:
"""Get contacts from your Telegram address book.
Returns:
@ -39,5 +40,5 @@ class GetContacts(BaseClient):
contacts = app.get_contacts()
print(contacts)
"""
contacts = self.send(functions.contacts.GetContacts(hash=0))
contacts = await self.send(functions.contacts.GetContacts(hash=0))
return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users)

View File

@ -21,7 +21,7 @@ from ...ext import BaseClient
class GetContactsCount(BaseClient):
def get_contacts_count(self) -> int:
async def get_contacts_count(self) -> int:
"""Get the total count of contacts from your Telegram address book.
Returns:
@ -34,4 +34,4 @@ class GetContactsCount(BaseClient):
print(count)
"""
return len(self.send(functions.contacts.GetContacts(hash=0)).contacts)
return len((await self.send(functions.contacts.GetContacts(hash=0))).contacts)

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class DeleteMessages(BaseClient):
def delete_messages(
async def delete_messages(
self,
chat_id: Union[int, str],
message_ids: Iterable[int],
@ -62,18 +62,18 @@ class DeleteMessages(BaseClient):
# Delete messages only on your side (without revoking)
app.delete_messages(chat_id, message_id, revoke=False)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
message_ids = list(message_ids) if not isinstance(message_ids, int) else [message_ids]
if isinstance(peer, types.InputPeerChannel):
r = self.send(
r = await self.send(
functions.channels.DeleteMessages(
channel=peer,
id=message_ids
)
)
else:
r = self.send(
r = await self.send(
functions.messages.DeleteMessages(
id=message_ids,
revoke=revoke or None

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import binascii
import os
import struct
@ -32,7 +33,7 @@ DEFAULT_DOWNLOAD_DIR = "downloads/"
class DownloadMedia(BaseClient):
def download_media(
async def download_media(
self,
message: Union["pyrogram.Message", str],
file_ref: str = None,
@ -202,7 +203,7 @@ class DownloadMedia(BaseClient):
except (AssertionError, binascii.Error, struct.error):
raise FileIdInvalid from None
done = Event()
done = asyncio.Event()
path = [None]
directory, file_name = os.path.split(file_name)
@ -239,9 +240,9 @@ class DownloadMedia(BaseClient):
)
# Cast to string because Path objects aren't supported by Python 3.5
self.download_queue.put((data, str(directory), str(file_name), done, progress, progress_args, path))
self.download_queue.put_nowait((data, str(directory), str(file_name), done, progress, progress_args, path))
if block:
done.wait()
await done.wait()
return path[0]

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class EditInlineCaption(BaseClient):
def edit_inline_caption(
async def edit_inline_caption(
self,
inline_message_id: str,
caption: str,
@ -58,7 +58,7 @@ class EditInlineCaption(BaseClient):
# Bots only
app.edit_inline_caption(inline_message_id, "new media caption")
"""
return self.edit_inline_text(
return await self.edit_inline_text(
inline_message_id=inline_message_id,
text=caption,
parse_mode=parse_mode,

View File

@ -29,7 +29,7 @@ from pyrogram.client.types.input_media import InputMedia
class EditInlineMedia(BaseClient):
def edit_inline_media(
async def edit_inline_media(
self,
inline_message_id: str,
media: InputMedia,
@ -109,11 +109,11 @@ class EditInlineMedia(BaseClient):
else:
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5)
return self.send(
return await self.send(
functions.messages.EditInlineBotMessage(
id=utils.unpack_inline_message_id(inline_message_id),
media=media,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)

View File

@ -22,7 +22,7 @@ from pyrogram.client.ext import BaseClient, utils
class EditInlineReplyMarkup(BaseClient):
def edit_inline_reply_markup(
async def edit_inline_reply_markup(
self,
inline_message_id: str,
reply_markup: "pyrogram.InlineKeyboardMarkup" = None
@ -50,7 +50,7 @@ class EditInlineReplyMarkup(BaseClient):
InlineKeyboardMarkup([[
InlineKeyboardButton("New button", callback_data="new_data")]]))
"""
return self.send(
return await self.send(
functions.messages.EditInlineBotMessage(
id=utils.unpack_inline_message_id(inline_message_id),
reply_markup=reply_markup.write() if reply_markup else None,

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient, utils
class EditInlineText(BaseClient):
def edit_inline_text(
async def edit_inline_text(
self,
inline_message_id: str,
text: str,
@ -71,11 +71,11 @@ class EditInlineText(BaseClient):
disable_web_page_preview=True)
"""
return self.send(
return await self.send(
functions.messages.EditInlineBotMessage(
id=utils.unpack_inline_message_id(inline_message_id),
no_webpage=disable_web_page_preview or None,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(text, parse_mode)
**await self.parser.parse(text, parse_mode)
)
)

View File

@ -23,7 +23,7 @@ from pyrogram.client.ext import BaseClient
class EditMessageCaption(BaseClient):
def edit_message_caption(
async def edit_message_caption(
self,
chat_id: Union[int, str],
message_id: int,
@ -63,7 +63,7 @@ class EditMessageCaption(BaseClient):
app.edit_message_caption(chat_id, message_id, "new media caption")
"""
return self.edit_message_text(
return await self.edit_message_text(
chat_id=chat_id,
message_id=message_id,
text=caption,

View File

@ -31,7 +31,7 @@ from pyrogram.client.types.input_media import InputMedia
class EditMessageMedia(BaseClient):
def edit_message_media(
async def edit_message_media(
self,
chat_id: Union[int, str],
message_id: int,
@ -85,11 +85,11 @@ class EditMessageMedia(BaseClient):
if isinstance(media, InputMediaPhoto):
if os.path.isfile(media.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedPhoto(
file=self.save_file(media.media)
file=await self.save_file(media.media)
)
)
)
@ -109,13 +109,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 2)
elif isinstance(media, InputMediaVideo):
if os.path.isfile(media.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "video/mp4",
thumb=self.save_file(media.thumb),
file=self.save_file(media.media),
thumb=await self.save_file(media.thumb),
file=await self.save_file(media.media),
attributes=[
types.DocumentAttributeVideo(
supports_streaming=media.supports_streaming or None,
@ -146,13 +146,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 4)
elif isinstance(media, InputMediaAudio):
if os.path.isfile(media.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "audio/mpeg",
thumb=self.save_file(media.thumb),
file=self.save_file(media.media),
thumb=await self.save_file(media.thumb),
file=await self.save_file(media.media),
attributes=[
types.DocumentAttributeAudio(
duration=media.duration,
@ -182,13 +182,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 9)
elif isinstance(media, InputMediaAnimation):
if os.path.isfile(media.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "video/mp4",
thumb=self.save_file(media.thumb),
file=self.save_file(media.media),
file=await self.save_file(media.media),
attributes=[
types.DocumentAttributeVideo(
supports_streaming=True,
@ -220,13 +220,13 @@ class EditMessageMedia(BaseClient):
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 10)
elif isinstance(media, InputMediaDocument):
if os.path.isfile(media.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(media.media) or "application/zip",
thumb=self.save_file(media.thumb),
file=self.save_file(media.media),
thumb=await self.save_file(media.thumb),
file=await self.save_file(media.media),
attributes=[
types.DocumentAttributeFilename(
file_name=file_name or os.path.basename(media.media)
@ -250,19 +250,19 @@ class EditMessageMedia(BaseClient):
else:
media = utils.get_input_media_from_file_id(media.media, media.file_ref, 5)
r = self.send(
r = await self.send(
functions.messages.EditMessage(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
id=message_id,
media=media,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats}

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class EditMessageReplyMarkup(BaseClient):
def edit_message_reply_markup(
async def edit_message_reply_markup(
self,
chat_id: Union[int, str],
message_id: int,
@ -58,9 +58,9 @@ class EditMessageReplyMarkup(BaseClient):
InlineKeyboardMarkup([[
InlineKeyboardButton("New button", callback_data="new_data")]]))
"""
r = self.send(
r = await self.send(
functions.messages.EditMessage(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
id=message_id,
reply_markup=reply_markup.write() if reply_markup else None,
)
@ -68,7 +68,7 @@ class EditMessageReplyMarkup(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats}

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class EditMessageText(BaseClient):
def edit_message_text(
async def edit_message_text(
self,
chat_id: Union[int, str],
message_id: int,
@ -75,19 +75,19 @@ class EditMessageText(BaseClient):
disable_web_page_preview=True)
"""
r = self.send(
r = await self.send(
functions.messages.EditMessage(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
id=message_id,
no_webpage=disable_web_page_preview or None,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(text, parse_mode)
**await self.parser.parse(text, parse_mode)
)
)
for i in r.updates:
if isinstance(i, (types.UpdateEditMessage, types.UpdateEditChannelMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats}

View File

@ -20,11 +20,12 @@ from typing import Union, Iterable, List
import pyrogram
from pyrogram.api import functions, types
from ...ext import BaseClient
class ForwardMessages(BaseClient):
def forward_messages(
async def forward_messages(
self,
chat_id: Union[int, str],
from_chat_id: Union[int, str],
@ -94,11 +95,11 @@ class ForwardMessages(BaseClient):
forwarded_messages = []
for chunk in [message_ids[i:i + 200] for i in range(0, len(message_ids), 200)]:
messages = self.get_messages(chat_id=from_chat_id, message_ids=chunk)
messages = await self.get_messages(chat_id=from_chat_id, message_ids=chunk)
for message in messages:
forwarded_messages.append(
message.forward(
await message.forward(
chat_id,
disable_notification=disable_notification,
as_copy=True,
@ -109,10 +110,10 @@ class ForwardMessages(BaseClient):
return pyrogram.List(forwarded_messages) if is_iterable else forwarded_messages[0]
else:
r = self.send(
r = await self.send(
functions.messages.ForwardMessages(
to_peer=self.resolve_peer(chat_id),
from_peer=self.resolve_peer(from_chat_id),
to_peer=await self.resolve_peer(chat_id),
from_peer=await self.resolve_peer(from_chat_id),
id=message_ids,
silent=disable_notification or None,
random_id=[self.rnd_id() for _ in message_ids],
@ -128,7 +129,7 @@ class ForwardMessages(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
forwarded_messages.append(
pyrogram.Message._parse(
await pyrogram.Message._parse(
self, i.message,
users, chats
)

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from typing import Union, List
@ -28,7 +29,7 @@ log = logging.getLogger(__name__)
class GetHistory(BaseClient):
def get_history(
async def get_history(
self,
chat_id: Union[int, str],
limit: int = 100,
@ -83,11 +84,11 @@ class GetHistory(BaseClient):
offset_id = offset_id or (1 if reverse else 0)
messages = utils.parse_messages(
messages = await utils.parse_messages(
self,
self.send(
await self.send(
functions.messages.GetHistory(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
offset_id=offset_id,
offset_date=offset_date,
add_offset=offset * (-1 if reverse else 1) - (limit if reverse else 0),

View File

@ -26,7 +26,7 @@ log = logging.getLogger(__name__)
class GetHistoryCount(BaseClient):
def get_history_count(
async def get_history_count(
self,
chat_id: Union[int, str]
) -> int:
@ -51,9 +51,9 @@ class GetHistoryCount(BaseClient):
app.get_history_count("pyrogramchat")
"""
r = self.send(
r = await self.send(
functions.messages.GetHistory(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
offset_id=0,
offset_date=0,
add_offset=0,

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from typing import Union, Iterable, List
@ -30,7 +31,7 @@ log = logging.getLogger(__name__)
class GetMessages(BaseClient):
def get_messages(
async def get_messages(
self,
chat_id: Union[int, str],
message_ids: Union[int, Iterable[int]] = None,
@ -96,7 +97,7 @@ class GetMessages(BaseClient):
if ids is None:
raise ValueError("No argument supplied. Either pass message_ids or reply_to_message_ids")
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
is_iterable = not isinstance(ids, int)
ids = list(ids) if is_iterable else [ids]
@ -110,8 +111,8 @@ class GetMessages(BaseClient):
else:
rpc = functions.messages.GetMessages(id=ids)
r = self.send(rpc)
r = await self.send(rpc)
messages = utils.parse_messages(self, r, replies=replies)
messages = await utils.parse_messages(self, r, replies=replies)
return messages if is_iterable else messages[0]

View File

@ -16,14 +16,17 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Union, Generator
from typing import Union, Optional, Generator
import pyrogram
from async_generator import async_generator, yield_
from ...ext import BaseClient
class IterHistory(BaseClient):
def iter_history(
@async_generator
async def iter_history(
self,
chat_id: Union[int, str],
limit: int = 0,
@ -31,7 +34,7 @@ class IterHistory(BaseClient):
offset_id: int = 0,
offset_date: int = 0,
reverse: bool = False
) -> Generator["pyrogram.Message", None, None]:
) -> Optional[Generator["pyrogram.Message", None, None]]:
"""Iterate through a chat history sequentially.
This convenience method does the same as repeatedly calling :meth:`~Client.get_history` in a loop, thus saving
@ -76,7 +79,7 @@ class IterHistory(BaseClient):
limit = min(100, total)
while True:
messages = self.get_history(
messages = await self.get_history(
chat_id=chat_id,
limit=limit,
offset=offset,
@ -91,7 +94,7 @@ class IterHistory(BaseClient):
offset_id = messages[-1].message_id + (1 if reverse else 0)
for message in messages:
yield message
await yield_(message)
current += 1

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class ReadHistory(BaseClient):
def read_history(
async def read_history(
self,
chat_id: Union[int, str],
max_id: int = 0
@ -53,7 +53,7 @@ class ReadHistory(BaseClient):
app.read_history("pyrogramlounge", 123456)
"""
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel):
q = functions.channels.ReadHistory(
@ -66,6 +66,6 @@ class ReadHistory(BaseClient):
max_id=max_id
)
self.send(q)
await self.send(q)
return True

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class RetractVote(BaseClient):
def retract_vote(
async def retract_vote(
self,
chat_id: Union[int, str],
message_id: int
@ -48,9 +48,9 @@ class RetractVote(BaseClient):
app.retract_vote(chat_id, message_id)
"""
r = self.send(
r = await self.send(
functions.messages.SendVote(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
msg_id=message_id,
options=[]
)

View File

@ -16,7 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Generator
from typing import Generator, Optional
from async_generator import async_generator, yield_
import pyrogram
from pyrogram.api import functions, types
@ -24,11 +26,12 @@ from pyrogram.client.ext import BaseClient, utils
class SearchGlobal(BaseClient):
def search_global(
@async_generator
async def search_global(
self,
query: str,
limit: int = 0,
) -> Generator["pyrogram.Message", None, None]:
) -> Optional[Generator["pyrogram.Message", None, None]]:
"""Search messages globally from all of your chats.
.. note::
@ -64,9 +67,9 @@ class SearchGlobal(BaseClient):
offset_id = 0
while True:
messages = utils.parse_messages(
messages = await utils.parse_messages(
self,
self.send(
await self.send(
functions.messages.SearchGlobal(
q=query,
offset_rate=offset_date,
@ -84,11 +87,11 @@ class SearchGlobal(BaseClient):
last = messages[-1]
offset_date = last.date
offset_peer = self.resolve_peer(last.chat.id)
offset_peer = await self.resolve_peer(last.chat.id)
offset_id = last.message_id
for message in messages:
yield message
await yield_(message)
current += 1

View File

@ -16,11 +16,12 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Union, List, Generator
from typing import Union, List, Generator, Optional
import pyrogram
from pyrogram.client.ext import BaseClient, utils
from pyrogram.api import functions, types
from async_generator import async_generator, yield_
class Filters:
@ -46,7 +47,7 @@ POSSIBLE_VALUES = list(map(lambda x: x.lower(), filter(lambda x: not x.startswit
# noinspection PyShadowingBuiltins
def get_chunk(
async def get_chunk(
client: BaseClient,
chat_id: Union[int, str],
query: str = "",
@ -61,9 +62,9 @@ def get_chunk(
raise ValueError('Invalid filter "{}". Possible values are: {}'.format(
filter, ", ".join('"{}"'.format(v) for v in POSSIBLE_VALUES))) from None
r = client.send(
r = await client.send(
functions.messages.Search(
peer=client.resolve_peer(chat_id),
peer=await client.resolve_peer(chat_id),
q=query,
filter=filter,
min_date=0,
@ -74,7 +75,7 @@ def get_chunk(
min_id=0,
max_id=0,
from_id=(
client.resolve_peer(from_user)
await client.resolve_peer(from_user)
if from_user
else None
),
@ -82,12 +83,13 @@ def get_chunk(
)
)
return utils.parse_messages(client, r)
return await utils.parse_messages(client, r)
class SearchMessages(BaseClient):
# noinspection PyShadowingBuiltins
def search_messages(
@async_generator
async def search_messages(
self,
chat_id: Union[int, str],
query: str = "",
@ -95,7 +97,7 @@ class SearchMessages(BaseClient):
filter: str = "empty",
limit: int = 0,
from_user: Union[int, str] = None
) -> Generator["pyrogram.Message", None, None]:
) -> Optional[Generator["pyrogram.Message", None, None]]:
"""Search for text and media messages inside a specific chat.
Parameters:
@ -160,7 +162,7 @@ class SearchMessages(BaseClient):
limit = min(100, total)
while True:
messages = get_chunk(
messages = await get_chunk(
client=self,
chat_id=chat_id,
query=query,
@ -176,7 +178,7 @@ class SearchMessages(BaseClient):
offset += 100
for message in messages:
yield message
await yield_(message)
current += 1

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendAnimation(BaseClient):
def send_animation(
async def send_animation(
self,
chat_id: Union[int, str],
animation: Union[str, BinaryIO],
@ -167,8 +167,8 @@ class SendAnimation(BaseClient):
try:
if isinstance(animation, str):
if os.path.isfile(animation):
thumb = self.save_file(thumb)
file = self.save_file(animation, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(animation, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(animation) or "video/mp4",
file=file,
@ -191,8 +191,8 @@ class SendAnimation(BaseClient):
else:
media = utils.get_input_media_from_file_id(animation, file_ref, 10)
else:
thumb = self.save_file(thumb)
file = self.save_file(animation, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(animation, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(animation.name) or "video/mp4",
file=file,
@ -211,27 +211,27 @@ class SendAnimation(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(),
schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
except FilePartMissing as e:
self.save_file(animation, file_id=file.id, file_part=e.x)
await self.save_file(animation, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
message = pyrogram.Message._parse(
message = await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},
@ -242,7 +242,7 @@ class SendAnimation(BaseClient):
document = message.animation or message.document
document_id = utils.get_input_media_from_file_id(document.file_id, document.file_ref).id
self.send(
await self.send(
functions.messages.SaveGif(
id=document_id,
unsave=True

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendAudio(BaseClient):
def send_audio(
async def send_audio(
self,
chat_id: Union[int, str],
audio: Union[str, BinaryIO],
@ -37,8 +37,7 @@ class SendAudio(BaseClient):
duration: int = 0,
performer: str = None,
title: str = None,
thumb: Union[str, BinaryIO] = None,
file_name: str = None,
thumb: Union[str, BinaryIO] = None, file_name: str = None,
disable_notification: bool = None,
reply_to_message_id: int = None,
schedule_date: int = None,
@ -167,8 +166,8 @@ class SendAudio(BaseClient):
try:
if isinstance(audio, str):
if os.path.isfile(audio):
thumb = self.save_file(thumb)
file = self.save_file(audio, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(audio, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(audio) or "audio/mpeg",
file=file,
@ -189,8 +188,8 @@ class SendAudio(BaseClient):
else:
media = utils.get_input_media_from_file_id(audio, file_ref, 9)
else:
thumb = self.save_file(thumb)
file = self.save_file(audio, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(audio, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(audio.name) or "audio/mpeg",
file=file,
@ -207,27 +206,27 @@ class SendAudio(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(),
schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
except FilePartMissing as e:
self.save_file(audio, file_id=file.id, file_part=e.x)
await self.save_file(audio, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient, utils
class SendCachedMedia(BaseClient):
def send_cached_media(
async def send_cached_media(
self,
chat_id: Union[int, str],
file_id: str,
@ -94,22 +94,22 @@ class SendCachedMedia(BaseClient):
app.send_cached_media("me", "CAADBAADzg4AAvLQYAEz_x2EOgdRwBYE")
"""
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=utils.get_input_media_from_file_id(file_id, file_ref),
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(),
schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -43,7 +43,7 @@ POSSIBLE_VALUES = list(map(lambda x: x.lower(), filter(lambda x: not x.startswit
class SendChatAction(BaseClient):
def send_chat_action(self, chat_id: Union[int, str], action: str) -> bool:
async def send_chat_action(self, chat_id: Union[int, str], action: str) -> bool:
"""Tell the other party that something is happening on your side.
Parameters:
@ -93,9 +93,9 @@ class SendChatAction(BaseClient):
else:
action = action()
return self.send(
return await self.send(
functions.messages.SetTyping(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
action=action
)
)

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendContact(BaseClient):
def send_contact(
async def send_contact(
self,
chat_id: Union[int, str],
phone_number: str,
@ -83,9 +83,9 @@ class SendContact(BaseClient):
app.send_contact("me", "+39 123 456 7890", "Dan")
"""
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaContact(
phone_number=phone_number,
first_name=first_name,
@ -103,7 +103,7 @@ class SendContact(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendDice(BaseClient):
def send_dice(
async def send_dice(
self,
chat_id: Union[int, str],
emoji: str = "🎲",
@ -80,9 +80,9 @@ class SendDice(BaseClient):
app.send_dice("pyrogramlounge", "🏀")
"""
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaDice(emoticon=emoji),
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
@ -94,11 +94,8 @@ class SendDice(BaseClient):
)
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendDocument(BaseClient):
def send_document(
async def send_document(
self,
chat_id: Union[int, str],
document: Union[str, BinaryIO],
@ -147,8 +147,8 @@ class SendDocument(BaseClient):
try:
if isinstance(document, str):
if os.path.isfile(document):
thumb = self.save_file(thumb)
file = self.save_file(document, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(document, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(document) or "application/zip",
file=file,
@ -165,8 +165,8 @@ class SendDocument(BaseClient):
else:
media = utils.get_input_media_from_file_id(document, file_ref, 5)
else:
thumb = self.save_file(thumb)
file = self.save_file(document, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(document, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(document.name) or "application/zip",
file=file,
@ -178,27 +178,27 @@ class SendDocument(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(),
schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
except FilePartMissing as e:
self.save_file(document, file_id=file.id, file_part=e.x)
await self.save_file(document, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendLocation(BaseClient):
def send_location(
async def send_location(
self,
chat_id: Union[int, str],
latitude: float,
@ -75,9 +75,9 @@ class SendLocation(BaseClient):
app.send_location("me", 51.500729, -0.124583)
"""
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaGeoPoint(
geo_point=types.InputGeoPoint(
lat=latitude,
@ -95,7 +95,7 @@ class SendLocation(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -30,7 +30,7 @@ log = logging.getLogger(__name__)
class SendMediaGroup(BaseClient):
# TODO: Add progress parameter
def send_media_group(
async def send_media_group(
self,
chat_id: Union[int, str],
media: List[Union["pyrogram.InputMediaPhoto", "pyrogram.InputMediaVideo"]],
@ -77,11 +77,11 @@ class SendMediaGroup(BaseClient):
for i in media:
if isinstance(i, pyrogram.InputMediaPhoto):
if os.path.isfile(i.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedPhoto(
file=self.save_file(i.media)
file=await self.save_file(i.media)
)
)
)
@ -94,9 +94,9 @@ class SendMediaGroup(BaseClient):
)
)
elif re.match("^https?://", i.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaPhotoExternal(
url=i.media
)
@ -114,11 +114,11 @@ class SendMediaGroup(BaseClient):
media = utils.get_input_media_from_file_id(i.media, i.file_ref, 2)
elif isinstance(i, pyrogram.InputMediaVideo):
if os.path.isfile(i.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaUploadedDocument(
file=self.save_file(i.media),
file=await self.save_file(i.media),
thumb=self.save_file(i.thumb),
mime_type=self.guess_mime_type(i.media) or "video/mp4",
attributes=[
@ -142,9 +142,9 @@ class SendMediaGroup(BaseClient):
)
)
elif re.match("^https?://", i.media):
media = self.send(
media = await self.send(
functions.messages.UploadMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaDocumentExternal(
url=i.media
)
@ -165,20 +165,20 @@ class SendMediaGroup(BaseClient):
types.InputSingleMedia(
media=media,
random_id=self.rnd_id(),
**self.parser.parse(i.caption, i.parse_mode)
**await self.parser.parse(i.caption, i.parse_mode)
)
)
r = self.send(
r = await self.send(
functions.messages.SendMultiMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
multi_media=multi_media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id
)
)
return utils.parse_messages(
return await utils.parse_messages(
self,
types.messages.Messages(
messages=[m.message for m in filter(

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class SendMessage(BaseClient):
def send_message(
async def send_message(
self,
chat_id: Union[int, str],
text: str,
@ -116,11 +116,11 @@ class SendMessage(BaseClient):
]))
"""
message, entities = self.parser.parse(text, parse_mode).values()
message, entities = (await self.parser.parse(text, parse_mode)).values()
r = self.send(
r = await self.send(
functions.messages.SendMessage(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
no_webpage=disable_web_page_preview or None,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
@ -133,7 +133,7 @@ class SendMessage(BaseClient):
)
if isinstance(r, types.UpdateShortSentMessage):
peer = self.resolve_peer(chat_id)
peer = await self.resolve_peer(chat_id)
peer_id = (
peer.user_id
@ -160,7 +160,7 @@ class SendMessage(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendPhoto(BaseClient):
def send_photo(
async def send_photo(
self,
chat_id: Union[int, str],
photo: Union[str, BinaryIO],
@ -141,7 +141,7 @@ class SendPhoto(BaseClient):
try:
if isinstance(photo, str):
if os.path.isfile(photo):
file = self.save_file(photo, progress=progress, progress_args=progress_args)
file = await self.save_file(photo, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedPhoto(
file=file,
ttl_seconds=ttl_seconds
@ -154,7 +154,7 @@ class SendPhoto(BaseClient):
else:
media = utils.get_input_media_from_file_id(photo, file_ref, 2)
else:
file = self.save_file(photo, progress=progress, progress_args=progress_args)
file = await self.save_file(photo, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedPhoto(
file=file,
ttl_seconds=ttl_seconds
@ -162,27 +162,27 @@ class SendPhoto(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(),
schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
except FilePartMissing as e:
self.save_file(photo, file_id=file.id, file_part=e.x)
await self.save_file(photo, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendPoll(BaseClient):
def send_poll(
async def send_poll(
self,
chat_id: Union[int, str],
question: str,
@ -95,9 +95,9 @@ class SendPoll(BaseClient):
app.send_poll(chat_id, "Is this a poll question?", ["Yes", "No", "Maybe"])
"""
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaPoll(
poll=types.Poll(
id=0,
@ -123,7 +123,7 @@ class SendPoll(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendSticker(BaseClient):
def send_sticker(
async def send_sticker(
self,
chat_id: Union[int, str],
sticker: Union[str, BinaryIO],
@ -117,7 +117,7 @@ class SendSticker(BaseClient):
try:
if isinstance(sticker, str):
if os.path.isfile(sticker):
file = self.save_file(sticker, progress=progress, progress_args=progress_args)
file = await self.save_file(sticker, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(sticker) or "image/webp",
file=file,
@ -132,7 +132,7 @@ class SendSticker(BaseClient):
else:
media = utils.get_input_media_from_file_id(sticker, file_ref, 8)
else:
file = self.save_file(sticker, progress=progress, progress_args=progress_args)
file = await self.save_file(sticker, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(sticker.name) or "image/webp",
file=file,
@ -143,9 +143,9 @@ class SendSticker(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
@ -156,14 +156,14 @@ class SendSticker(BaseClient):
)
)
except FilePartMissing as e:
self.save_file(sticker, file_id=file.id, file_part=e.x)
await self.save_file(sticker, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class SendVenue(BaseClient):
def send_venue(
async def send_venue(
self,
chat_id: Union[int, str],
latitude: float,
@ -94,9 +94,9 @@ class SendVenue(BaseClient):
"me", 51.500729, -0.124583,
"Elizabeth Tower", "Westminster, London SW1A 0AA, UK")
"""
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=types.InputMediaVenue(
geo_point=types.InputGeoPoint(
lat=latitude,
@ -119,7 +119,7 @@ class SendVenue(BaseClient):
for i in r.updates:
if isinstance(i, (types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendVideo(BaseClient):
def send_video(
async def send_video(
self,
chat_id: Union[int, str],
video: Union[str, BinaryIO],
@ -164,8 +164,8 @@ class SendVideo(BaseClient):
try:
if isinstance(video, str):
if os.path.isfile(video):
thumb = self.save_file(thumb)
file = self.save_file(video, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(video, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video) or "video/mp4",
file=file,
@ -187,8 +187,8 @@ class SendVideo(BaseClient):
else:
media = utils.get_input_media_from_file_id(video, file_ref, 4)
else:
thumb = self.save_file(thumb)
file = self.save_file(video, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(video, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video.name) or "video/mp4",
file=file,
@ -206,27 +206,27 @@ class SendVideo(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(),
schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
except FilePartMissing as e:
self.save_file(video, file_id=file.id, file_part=e.x)
await self.save_file(video, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -26,7 +26,7 @@ from pyrogram.errors import FilePartMissing
class SendVideoNote(BaseClient):
def send_video_note(
async def send_video_note(
self,
chat_id: Union[int, str],
video_note: Union[str, BinaryIO],
@ -131,8 +131,8 @@ class SendVideoNote(BaseClient):
try:
if isinstance(video_note, str):
if os.path.isfile(video_note):
thumb = self.save_file(thumb)
file = self.save_file(video_note, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(video_note, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video_note) or "video/mp4",
file=file,
@ -149,8 +149,8 @@ class SendVideoNote(BaseClient):
else:
media = utils.get_input_media_from_file_id(video_note, file_ref, 13)
else:
thumb = self.save_file(thumb)
file = self.save_file(video_note, progress=progress, progress_args=progress_args)
thumb = await self.save_file(thumb)
file = await self.save_file(video_note, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(video_note.name) or "video/mp4",
file=file,
@ -167,9 +167,9 @@ class SendVideoNote(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
@ -180,14 +180,14 @@ class SendVideoNote(BaseClient):
)
)
except FilePartMissing as e:
self.save_file(video_note, file_id=file.id, file_part=e.x)
await self.save_file(video_note, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -27,7 +27,7 @@ from pyrogram.errors import FilePartMissing
class SendVoice(BaseClient):
def send_voice(
async def send_voice(
self,
chat_id: Union[int, str],
voice: Union[str, BinaryIO],
@ -136,7 +136,7 @@ class SendVoice(BaseClient):
try:
if isinstance(voice, str):
if os.path.isfile(voice):
file = self.save_file(voice, progress=progress, progress_args=progress_args)
file = await self.save_file(voice, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(voice) or "audio/mpeg",
file=file,
@ -154,7 +154,7 @@ class SendVoice(BaseClient):
else:
media = utils.get_input_media_from_file_id(voice, file_ref, 3)
else:
file = self.save_file(voice, progress=progress, progress_args=progress_args)
file = await self.save_file(voice, progress=progress, progress_args=progress_args)
media = types.InputMediaUploadedDocument(
mime_type=self.guess_mime_type(voice.name) or "audio/mpeg",
file=file,
@ -168,27 +168,27 @@ class SendVoice(BaseClient):
while True:
try:
r = self.send(
r = await self.send(
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
media=media,
silent=disable_notification or None,
reply_to_msg_id=reply_to_message_id,
random_id=self.rnd_id(),
schedule_date=schedule_date,
reply_markup=reply_markup.write() if reply_markup else None,
**self.parser.parse(caption, parse_mode)
**await self.parser.parse(caption, parse_mode)
)
)
except FilePartMissing as e:
self.save_file(voice, file_id=file.id, file_part=e.x)
await self.save_file(voice, file_id=file.id, file_part=e.x)
else:
for i in r.updates:
if isinstance(
i,
(types.UpdateNewMessage, types.UpdateNewChannelMessage, types.UpdateNewScheduledMessage)
):
return pyrogram.Message._parse(
return await pyrogram.Message._parse(
self, i.message,
{i.id: i for i in r.users},
{i.id: i for i in r.chats},

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class StopPoll(BaseClient):
def stop_poll(
async def stop_poll(
self,
chat_id: Union[int, str],
message_id: int,
@ -54,11 +54,11 @@ class StopPoll(BaseClient):
app.stop_poll(chat_id, message_id)
"""
poll = self.get_messages(chat_id, message_id).poll
poll = (await self.get_messages(chat_id, message_id)).poll
r = self.send(
r = await self.send(
functions.messages.EditMessage(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
id=message_id,
media=types.InputMediaPoll(
poll=types.Poll(

View File

@ -24,7 +24,7 @@ from pyrogram.client.ext import BaseClient
class VotePoll(BaseClient):
def vote_poll(
async def vote_poll(
self,
chat_id: Union[int, str],
message_id: id,
@ -53,12 +53,12 @@ class VotePoll(BaseClient):
app.vote_poll(chat_id, message_id, 6)
"""
poll = self.get_messages(chat_id, message_id).poll
poll = (await self.get_messages(chat_id, message_id)).poll
options = [options] if not isinstance(options, list) else options
r = self.send(
r = await self.send(
functions.messages.SendVote(
peer=self.resolve_peer(chat_id),
peer=await self.resolve_peer(chat_id),
msg_id=message_id,
options=[poll.options[option].data for option in options]
)

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class ChangeCloudPassword(BaseClient):
def change_cloud_password(
async def change_cloud_password(
self,
current_password: str,
new_password: str,
@ -57,7 +57,7 @@ class ChangeCloudPassword(BaseClient):
# Change password and hint
app.change_cloud_password("current_password", "new_password", new_hint="hint")
"""
r = self.send(functions.account.GetPassword())
r = await self.send(functions.account.GetPassword())
if not r.has_password:
raise ValueError("There is no cloud password to change")
@ -66,7 +66,7 @@ class ChangeCloudPassword(BaseClient):
new_hash = btoi(compute_hash(r.new_algo, new_password))
new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p)))
self.send(
await self.send(
functions.account.UpdatePasswordSettings(
password=compute_check(r, current_password),
new_settings=types.account.PasswordInputSettings(

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class EnableCloudPassword(BaseClient):
def enable_cloud_password(
async def enable_cloud_password(
self,
password: str,
hint: str = "",
@ -62,7 +62,7 @@ class EnableCloudPassword(BaseClient):
# Enable password with hint and email
app.enable_cloud_password("password", hint="hint", email="user@email.com")
"""
r = self.send(functions.account.GetPassword())
r = await self.send(functions.account.GetPassword())
if r.has_password:
raise ValueError("There is already a cloud password enabled")
@ -71,7 +71,7 @@ class EnableCloudPassword(BaseClient):
new_hash = btoi(compute_hash(r.new_algo, password))
new_hash = itob(pow(r.new_algo.g, new_hash, btoi(r.new_algo.p)))
self.send(
await self.send(
functions.account.UpdatePasswordSettings(
password=types.InputCheckPasswordEmpty(),
new_settings=types.account.PasswordInputSettings(

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class RemoveCloudPassword(BaseClient):
def remove_cloud_password(
async def remove_cloud_password(
self,
password: str
) -> bool:
@ -43,12 +43,12 @@ class RemoveCloudPassword(BaseClient):
app.remove_cloud_password("password")
"""
r = self.send(functions.account.GetPassword())
r = await self.send(functions.account.GetPassword())
if not r.has_password:
raise ValueError("There is no cloud password to remove")
self.send(
await self.send(
functions.account.UpdatePasswordSettings(
password=compute_check(r, password),
new_settings=types.account.PasswordInputSettings(

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class BlockUser(BaseClient):
def block_user(
async def block_user(
self,
user_id: Union[int, str]
) -> bool:
@ -44,9 +44,9 @@ class BlockUser(BaseClient):
app.block_user(user_id)
"""
return bool(
self.send(
await self.send(
functions.contacts.Block(
id=self.resolve_peer(user_id)
id=await self.resolve_peer(user_id)
)
)
)

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class DeleteProfilePhotos(BaseClient):
def delete_profile_photos(
async def delete_profile_photos(
self,
photo_ids: Union[str, List[str]]
) -> bool:
@ -53,7 +53,7 @@ class DeleteProfilePhotos(BaseClient):
photo_ids = photo_ids if isinstance(photo_ids, list) else [photo_ids]
input_photos = [utils.get_input_media_from_file_id(i).id for i in photo_ids]
return bool(self.send(
return bool(await self.send(
functions.photos.DeletePhotos(
id=input_photos
)

View File

@ -24,7 +24,7 @@ from ...ext import BaseClient
class GetCommonChats(BaseClient):
def get_common_chats(self, user_id: Union[int, str]) -> list:
async def get_common_chats(self, user_id: Union[int, str]) -> list:
"""Get the common chats you have with a user.
Parameters:
@ -46,10 +46,10 @@ class GetCommonChats(BaseClient):
print(common)
"""
peer = self.resolve_peer(user_id)
peer = await self.resolve_peer(user_id)
if isinstance(peer, types.InputPeerUser):
r = self.send(
r = await self.send(
functions.messages.GetCommonChats(
user_id=peer,
max_id=0,

View File

@ -22,7 +22,7 @@ from ...ext import BaseClient
class GetMe(BaseClient):
def get_me(self) -> "pyrogram.User":
async def get_me(self) -> "pyrogram.User":
"""Get your own user identity.
Returns:
@ -36,9 +36,9 @@ class GetMe(BaseClient):
"""
return pyrogram.User._parse(
self,
self.send(
(await self.send(
functions.users.GetFullUser(
id=types.InputPeerSelf()
)
).user
)).user
)

View File

@ -25,7 +25,7 @@ from ...ext import BaseClient
class GetProfilePhotos(BaseClient):
def get_profile_photos(
async def get_profile_photos(
self,
chat_id: Union[int, str],
offset: int = 0,
@ -62,12 +62,12 @@ class GetProfilePhotos(BaseClient):
# Get 3 profile photos of a user, skip the first 5
app.get_profile_photos("haskell", limit=3, offset=5)
"""
peer_id = self.resolve_peer(chat_id)
peer_id = await self.resolve_peer(chat_id)
if isinstance(peer_id, types.InputPeerChannel):
r = utils.parse_messages(
r = await utils.parse_messages(
self,
self.send(
await self.send(
functions.messages.Search(
peer=peer_id,
q="",
@ -86,7 +86,7 @@ class GetProfilePhotos(BaseClient):
return pyrogram.List([message.new_chat_photo for message in r][:limit])
else:
r = self.send(
r = await self.send(
functions.photos.GetUserPhotos(
user_id=peer_id,
offset=offset,

View File

@ -23,7 +23,7 @@ from ...ext import BaseClient
class GetProfilePhotosCount(BaseClient):
def get_profile_photos_count(self, chat_id: Union[int, str]) -> int:
async def get_profile_photos_count(self, chat_id: Union[int, str]) -> int:
"""Get the total count of profile pictures for a user.
Parameters:
@ -42,10 +42,10 @@ class GetProfilePhotosCount(BaseClient):
print(count)
"""
peer_id = self.resolve_peer(chat_id)
peer_id = await self.resolve_peer(chat_id)
if isinstance(peer_id, types.InputPeerChannel):
r = self.send(
r = await self.send(
functions.messages.GetSearchCounters(
peer=peer_id,
filters=[types.InputMessagesFilterChatPhotos()],
@ -54,7 +54,7 @@ class GetProfilePhotosCount(BaseClient):
return r[0].count
else:
r = self.send(
r = await self.send(
functions.photos.GetUserPhotos(
user_id=peer_id,
offset=0,

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from typing import Iterable, Union, List
import pyrogram
@ -24,7 +25,7 @@ from ...ext import BaseClient
class GetUsers(BaseClient):
def get_users(
async def get_users(
self,
user_ids: Union[Iterable[Union[int, str]], int, str]
) -> Union["pyrogram.User", List["pyrogram.User"]]:
@ -53,9 +54,9 @@ class GetUsers(BaseClient):
"""
is_iterable = not isinstance(user_ids, (int, str))
user_ids = list(user_ids) if is_iterable else [user_ids]
user_ids = [self.resolve_peer(i) for i in user_ids]
user_ids = await asyncio.gather(*[self.resolve_peer(i) for i in user_ids])
r = self.send(
r = await self.send(
functions.users.GetUsers(
id=user_ids
)

Some files were not shown because too many files have changed in this diff Show More