Move crypto calls to threads in case of big enough chunks

This commit is contained in:
Dan 2020-12-07 19:16:46 +01:00
parent 521e403f92
commit 844e53a70e
5 changed files with 34 additions and 49 deletions

View File

@ -20,6 +20,8 @@ __version__ = "1.0.7"
__license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)" __license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)"
__copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>" __copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>"
from concurrent.futures.thread import ThreadPoolExecutor
class StopTransmission(StopAsyncIteration): class StopTransmission(StopAsyncIteration):
pass pass
@ -41,3 +43,7 @@ from .sync import idle
# Save the main thread loop for future references # Save the main thread loop for future references
main_event_loop = asyncio.get_event_loop() main_event_loop = asyncio.get_event_loop()
CRYPTO_EXECUTOR_SIZE_THRESHOLD = 512
crypto_executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")

View File

@ -45,6 +45,7 @@ class TCP:
self.writer = None # type: asyncio.StreamWriter self.writer = None # type: asyncio.StreamWriter
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
if proxy.get("enabled", False): if proxy.get("enabled", False):
hostname = proxy.get("hostname", None) hostname = proxy.get("hostname", None)

View File

@ -19,6 +19,7 @@
import logging import logging
import os import os
from pyrogram import utils
from pyrogram.crypto import aes from pyrogram.crypto import aes
from .tcp import TCP from .tcp import TCP
@ -55,16 +56,10 @@ class TCPAbridgedO(TCP):
async def send(self, data: bytes, *args): async def send(self, data: bytes, *args):
length = len(data) // 4 length = len(data) // 4
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
payload = await utils.maybe_run_in_executor(aes.ctr256_encrypt, data, len(data), self.loop, *self.encrypt)
await super().send( await super().send(payload)
aes.ctr256_encrypt(
(bytes([length])
if length <= 126
else b"\x7f" + length.to_bytes(3, "little"))
+ data,
*self.encrypt
)
)
async def recv(self, length: int = 0) -> bytes or None: async def recv(self, length: int = 0) -> bytes or None:
length = await super().recv(1) length = await super().recv(1)
@ -87,4 +82,4 @@ class TCPAbridgedO(TCP):
if data is None: if data is None:
return None return None
return aes.ctr256_decrypt(data, *self.decrypt) return await utils.maybe_run_in_executor(aes.ctr256_decrypt, data, len(data), self.loop, *self.decrypt)

View File

@ -19,13 +19,12 @@
import asyncio import asyncio
import logging import logging
import os import os
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime, timedelta from datetime import datetime, timedelta
from hashlib import sha1 from hashlib import sha1
from io import BytesIO from io import BytesIO
import pyrogram import pyrogram
from pyrogram import __copyright__, __license__, __version__ from pyrogram import __copyright__, __license__, __version__, utils
from pyrogram import raw from pyrogram import raw
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import mtproto from pyrogram.crypto import mtproto
@ -51,7 +50,6 @@ class Session:
MAX_RETRIES = 5 MAX_RETRIES = 5
ACKS_THRESHOLD = 8 ACKS_THRESHOLD = 8
PING_INTERVAL = 5 PING_INTERVAL = 5
EXECUTOR_SIZE_THRESHOLD = 512
notice_displayed = False notice_displayed = False
@ -69,8 +67,6 @@ class Session:
64: "[64] invalid container" 64: "[64] invalid container"
} }
executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")
def __init__( def __init__(
self, self,
client: "pyrogram.Client", client: "pyrogram.Client",
@ -220,22 +216,12 @@ class Session:
await self.start() await self.start()
async def handle_packet(self, packet): async def handle_packet(self, packet):
if len(packet) <= self.EXECUTOR_SIZE_THRESHOLD: data = await utils.maybe_run_in_executor(
data = mtproto.unpack( mtproto.unpack, BytesIO(packet), len(packet), self.loop,
BytesIO(packet), self.session_id,
self.session_id, self.auth_key,
self.auth_key, self.auth_key_id
self.auth_key_id )
)
else:
data = await self.loop.run_in_executor(
self.executor,
mtproto.unpack,
BytesIO(packet),
self.session_id,
self.auth_key,
self.auth_key_id
)
messages = ( messages = (
data.body.messages data.body.messages
@ -375,24 +361,13 @@ class Session:
log.debug(f"Sent:") log.debug(f"Sent:")
log.debug(message) log.debug(message)
if len(message) <= self.EXECUTOR_SIZE_THRESHOLD: payload = await utils.maybe_run_in_executor(
payload = mtproto.pack( mtproto.pack, message, len(message), self.loop,
message, self.current_salt.salt,
self.current_salt.salt, self.session_id,
self.session_id, self.auth_key,
self.auth_key, self.auth_key_id
self.auth_key_id )
)
else:
payload = await self.loop.run_in_executor(
self.executor,
mtproto.pack,
message,
self.current_salt.salt,
self.session_id,
self.auth_key,
self.auth_key_id
)
try: try:
await self.connection.send(payload) await self.connection.send(payload)

View File

@ -315,3 +315,11 @@ async def parse_text_entities(
"message": text, "message": text,
"entities": entities "entities": entities
} }
async def maybe_run_in_executor(func, data, length, loop, *args):
return (
func(data, *args)
if length <= pyrogram.CRYPTO_EXECUTOR_SIZE_THRESHOLD
else await loop.run_in_executor(pyrogram.crypto_executor, func, data, *args)
)