diff --git a/pyrogram/__init__.py b/pyrogram/__init__.py index 39b83df7..0d0669e7 100644 --- a/pyrogram/__init__.py +++ b/pyrogram/__init__.py @@ -20,6 +20,8 @@ __version__ = "1.0.7" __license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)" __copyright__ = "Copyright (C) 2017-2020 Dan " +from concurrent.futures.thread import ThreadPoolExecutor + class StopTransmission(StopAsyncIteration): pass @@ -41,3 +43,7 @@ from .sync import idle # Save the main thread loop for future references main_event_loop = asyncio.get_event_loop() + +CRYPTO_EXECUTOR_SIZE_THRESHOLD = 512 + +crypto_executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker") diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index b2ff5391..acc248b7 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -45,6 +45,7 @@ class TCP: self.writer = None # type: asyncio.StreamWriter self.lock = asyncio.Lock() + self.loop = asyncio.get_event_loop() if proxy.get("enabled", False): hostname = proxy.get("hostname", None) diff --git a/pyrogram/connection/transport/tcp/tcp_abridged_o.py b/pyrogram/connection/transport/tcp/tcp_abridged_o.py index c7b24159..04644a63 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged_o.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged_o.py @@ -19,6 +19,7 @@ import logging import os +from pyrogram import utils from pyrogram.crypto import aes from .tcp import TCP @@ -55,16 +56,10 @@ class TCPAbridgedO(TCP): async def send(self, data: bytes, *args): length = len(data) // 4 + data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data + payload = await utils.maybe_run_in_executor(aes.ctr256_encrypt, data, len(data), self.loop, *self.encrypt) - await super().send( - aes.ctr256_encrypt( - (bytes([length]) - if length <= 126 - else b"\x7f" + length.to_bytes(3, "little")) - + data, - *self.encrypt - ) - ) + await super().send(payload) async def recv(self, length: int = 0) -> bytes or None: length = await super().recv(1) @@ -87,4 +82,4 @@ class TCPAbridgedO(TCP): if data is None: return None - return aes.ctr256_decrypt(data, *self.decrypt) + return await utils.maybe_run_in_executor(aes.ctr256_decrypt, data, len(data), self.loop, *self.decrypt) diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 490eea54..d46751d4 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -19,13 +19,12 @@ import asyncio import logging import os -from concurrent.futures.thread import ThreadPoolExecutor from datetime import datetime, timedelta from hashlib import sha1 from io import BytesIO import pyrogram -from pyrogram import __copyright__, __license__, __version__ +from pyrogram import __copyright__, __license__, __version__, utils from pyrogram import raw from pyrogram.connection import Connection from pyrogram.crypto import mtproto @@ -51,7 +50,6 @@ class Session: MAX_RETRIES = 5 ACKS_THRESHOLD = 8 PING_INTERVAL = 5 - EXECUTOR_SIZE_THRESHOLD = 512 notice_displayed = False @@ -69,8 +67,6 @@ class Session: 64: "[64] invalid container" } - executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker") - def __init__( self, client: "pyrogram.Client", @@ -220,22 +216,12 @@ class Session: await self.start() async def handle_packet(self, packet): - if len(packet) <= self.EXECUTOR_SIZE_THRESHOLD: - data = mtproto.unpack( - BytesIO(packet), - self.session_id, - self.auth_key, - self.auth_key_id - ) - else: - data = await self.loop.run_in_executor( - self.executor, - mtproto.unpack, - BytesIO(packet), - self.session_id, - self.auth_key, - self.auth_key_id - ) + data = await utils.maybe_run_in_executor( + mtproto.unpack, BytesIO(packet), len(packet), self.loop, + self.session_id, + self.auth_key, + self.auth_key_id + ) messages = ( data.body.messages @@ -375,24 +361,13 @@ class Session: log.debug(f"Sent:") log.debug(message) - if len(message) <= self.EXECUTOR_SIZE_THRESHOLD: - payload = mtproto.pack( - message, - self.current_salt.salt, - self.session_id, - self.auth_key, - self.auth_key_id - ) - else: - payload = await self.loop.run_in_executor( - self.executor, - mtproto.pack, - message, - self.current_salt.salt, - self.session_id, - self.auth_key, - self.auth_key_id - ) + payload = await utils.maybe_run_in_executor( + mtproto.pack, message, len(message), self.loop, + self.current_salt.salt, + self.session_id, + self.auth_key, + self.auth_key_id + ) try: await self.connection.send(payload) diff --git a/pyrogram/utils.py b/pyrogram/utils.py index 4c444533..8b6242eb 100644 --- a/pyrogram/utils.py +++ b/pyrogram/utils.py @@ -315,3 +315,11 @@ async def parse_text_entities( "message": text, "entities": entities } + + +async def maybe_run_in_executor(func, data, length, loop, *args): + return ( + func(data, *args) + if length <= pyrogram.CRYPTO_EXECUTOR_SIZE_THRESHOLD + else await loop.run_in_executor(pyrogram.crypto_executor, func, data, *args) + )