Allow to specify a limit to concurrent transmissions

This commit is contained in:
Dan 2023-01-12 18:19:15 +01:00
parent 2a7110e257
commit 82b8c7792e
2 changed files with 14 additions and 5 deletions

View File

@ -172,6 +172,11 @@ class Client(Methods):
Pass True to hide the password when typing it during the login. Pass True to hide the password when typing it during the login.
Defaults to False, because ``getpass`` (the library used) is known to be problematic in some Defaults to False, because ``getpass`` (the library used) is known to be problematic in some
terminal environments. terminal environments.
max_concurrent_transmissions (``bool``, *optional*):
Set the maximum amount of concurrent transmissions (uploads & downloads).
A value that is too high may result in network related issues.
Defaults to 1.
""" """
APP_VERSION = f"Pyrogram {__version__}" APP_VERSION = f"Pyrogram {__version__}"
@ -189,6 +194,8 @@ class Client(Methods):
# Interval of seconds in which the updates watchdog will kick in # Interval of seconds in which the updates watchdog will kick in
UPDATES_WATCHDOG_INTERVAL = 5 * 60 UPDATES_WATCHDOG_INTERVAL = 5 * 60
MAX_CONCURRENT_TRANSMISSIONS = 1
mimetypes = MimeTypes() mimetypes = MimeTypes()
mimetypes.readfp(StringIO(mime_types)) mimetypes.readfp(StringIO(mime_types))
@ -217,7 +224,8 @@ class Client(Methods):
no_updates: bool = None, no_updates: bool = None,
takeout: bool = None, takeout: bool = None,
sleep_threshold: int = Session.SLEEP_THRESHOLD, sleep_threshold: int = Session.SLEEP_THRESHOLD,
hide_password: bool = False hide_password: bool = False,
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS
): ):
super().__init__() super().__init__()
@ -245,6 +253,7 @@ class Client(Methods):
self.takeout = takeout self.takeout = takeout
self.sleep_threshold = sleep_threshold self.sleep_threshold = sleep_threshold
self.hide_password = hide_password self.hide_password = hide_password
self.max_concurrent_transmissions = max_concurrent_transmissions
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler") self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")
@ -266,8 +275,8 @@ class Client(Methods):
self.media_sessions = {} self.media_sessions = {}
self.media_sessions_lock = asyncio.Lock() self.media_sessions_lock = asyncio.Lock()
self.save_file_lock = asyncio.Lock() self.save_file_semaphore = asyncio.Semaphore(self.max_concurrent_transmissions)
self.get_file_lock = asyncio.Lock() self.get_file_semaphore = asyncio.Semaphore(self.max_concurrent_transmissions)
self.is_connected = None self.is_connected = None
self.is_initialized = None self.is_initialized = None
@ -798,7 +807,7 @@ class Client(Methods):
progress: Callable = None, progress: Callable = None,
progress_args: tuple = () progress_args: tuple = ()
) -> Optional[AsyncGenerator[bytes, None]]: ) -> Optional[AsyncGenerator[bytes, None]]:
async with self.get_file_lock: async with self.get_file_semaphore:
file_type = file_id.file_type file_type = file_id.file_type
if file_type == FileType.CHAT_PHOTO: if file_type == FileType.CHAT_PHOTO:

View File

@ -94,7 +94,7 @@ class SaveFile:
Raises: Raises:
RPCError: In case of a Telegram RPC error. RPCError: In case of a Telegram RPC error.
""" """
async with self.save_file_lock: async with self.save_file_semaphore:
if path is None: if path is None:
return None return None