Move crypto calls to threads in case of big enough chunks

This commit is contained in:
Dan 2020-12-07 19:16:46 +01:00
parent 521e403f92
commit 844e53a70e
5 changed files with 34 additions and 49 deletions

View File

@ -20,6 +20,8 @@ __version__ = "1.0.7"
__license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)"
__copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>"
from concurrent.futures.thread import ThreadPoolExecutor
class StopTransmission(StopAsyncIteration):
pass
@ -41,3 +43,7 @@ from .sync import idle
# Save the main thread loop for future references
main_event_loop = asyncio.get_event_loop()
CRYPTO_EXECUTOR_SIZE_THRESHOLD = 512
crypto_executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")

View File

@ -45,6 +45,7 @@ class TCP:
self.writer = None # type: asyncio.StreamWriter
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
if proxy.get("enabled", False):
hostname = proxy.get("hostname", None)

View File

@ -19,6 +19,7 @@
import logging
import os
from pyrogram import utils
from pyrogram.crypto import aes
from .tcp import TCP
@ -55,16 +56,10 @@ class TCPAbridgedO(TCP):
async def send(self, data: bytes, *args):
length = len(data) // 4
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
payload = await utils.maybe_run_in_executor(aes.ctr256_encrypt, data, len(data), self.loop, *self.encrypt)
await super().send(
aes.ctr256_encrypt(
(bytes([length])
if length <= 126
else b"\x7f" + length.to_bytes(3, "little"))
+ data,
*self.encrypt
)
)
await super().send(payload)
async def recv(self, length: int = 0) -> bytes or None:
length = await super().recv(1)
@ -87,4 +82,4 @@ class TCPAbridgedO(TCP):
if data is None:
return None
return aes.ctr256_decrypt(data, *self.decrypt)
return await utils.maybe_run_in_executor(aes.ctr256_decrypt, data, len(data), self.loop, *self.decrypt)

View File

@ -19,13 +19,12 @@
import asyncio
import logging
import os
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime, timedelta
from hashlib import sha1
from io import BytesIO
import pyrogram
from pyrogram import __copyright__, __license__, __version__
from pyrogram import __copyright__, __license__, __version__, utils
from pyrogram import raw
from pyrogram.connection import Connection
from pyrogram.crypto import mtproto
@ -51,7 +50,6 @@ class Session:
MAX_RETRIES = 5
ACKS_THRESHOLD = 8
PING_INTERVAL = 5
EXECUTOR_SIZE_THRESHOLD = 512
notice_displayed = False
@ -69,8 +67,6 @@ class Session:
64: "[64] invalid container"
}
executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")
def __init__(
self,
client: "pyrogram.Client",
@ -220,18 +216,8 @@ class Session:
await self.start()
async def handle_packet(self, packet):
if len(packet) <= self.EXECUTOR_SIZE_THRESHOLD:
data = mtproto.unpack(
BytesIO(packet),
self.session_id,
self.auth_key,
self.auth_key_id
)
else:
data = await self.loop.run_in_executor(
self.executor,
mtproto.unpack,
BytesIO(packet),
data = await utils.maybe_run_in_executor(
mtproto.unpack, BytesIO(packet), len(packet), self.loop,
self.session_id,
self.auth_key,
self.auth_key_id
@ -375,19 +361,8 @@ class Session:
log.debug(f"Sent:")
log.debug(message)
if len(message) <= self.EXECUTOR_SIZE_THRESHOLD:
payload = mtproto.pack(
message,
self.current_salt.salt,
self.session_id,
self.auth_key,
self.auth_key_id
)
else:
payload = await self.loop.run_in_executor(
self.executor,
mtproto.pack,
message,
payload = await utils.maybe_run_in_executor(
mtproto.pack, message, len(message), self.loop,
self.current_salt.salt,
self.session_id,
self.auth_key,

View File

@ -315,3 +315,11 @@ async def parse_text_entities(
"message": text,
"entities": entities
}
async def maybe_run_in_executor(func, data, length, loop, *args):
return (
func(data, *args)
if length <= pyrogram.CRYPTO_EXECUTOR_SIZE_THRESHOLD
else await loop.run_in_executor(pyrogram.crypto_executor, func, data, *args)
)