mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-27 16:45:19 +00:00
Move crypto calls to threads in case of big enough chunks
This commit is contained in:
parent
521e403f92
commit
844e53a70e
@ -20,6 +20,8 @@ __version__ = "1.0.7"
|
||||
__license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)"
|
||||
__copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>"
|
||||
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
|
||||
class StopTransmission(StopAsyncIteration):
|
||||
pass
|
||||
@ -41,3 +43,7 @@ from .sync import idle
|
||||
|
||||
# Save the main thread loop for future references
|
||||
main_event_loop = asyncio.get_event_loop()
|
||||
|
||||
CRYPTO_EXECUTOR_SIZE_THRESHOLD = 512
|
||||
|
||||
crypto_executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")
|
||||
|
@ -45,6 +45,7 @@ class TCP:
|
||||
self.writer = None # type: asyncio.StreamWriter
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
if proxy.get("enabled", False):
|
||||
hostname = proxy.get("hostname", None)
|
||||
|
@ -19,6 +19,7 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from pyrogram import utils
|
||||
from pyrogram.crypto import aes
|
||||
from .tcp import TCP
|
||||
|
||||
@ -55,16 +56,10 @@ class TCPAbridgedO(TCP):
|
||||
|
||||
async def send(self, data: bytes, *args):
|
||||
length = len(data) // 4
|
||||
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
|
||||
payload = await utils.maybe_run_in_executor(aes.ctr256_encrypt, data, len(data), self.loop, *self.encrypt)
|
||||
|
||||
await super().send(
|
||||
aes.ctr256_encrypt(
|
||||
(bytes([length])
|
||||
if length <= 126
|
||||
else b"\x7f" + length.to_bytes(3, "little"))
|
||||
+ data,
|
||||
*self.encrypt
|
||||
)
|
||||
)
|
||||
await super().send(payload)
|
||||
|
||||
async def recv(self, length: int = 0) -> bytes or None:
|
||||
length = await super().recv(1)
|
||||
@ -87,4 +82,4 @@ class TCPAbridgedO(TCP):
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
return aes.ctr256_decrypt(data, *self.decrypt)
|
||||
return await utils.maybe_run_in_executor(aes.ctr256_decrypt, data, len(data), self.loop, *self.decrypt)
|
||||
|
@ -19,13 +19,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from hashlib import sha1
|
||||
from io import BytesIO
|
||||
|
||||
import pyrogram
|
||||
from pyrogram import __copyright__, __license__, __version__
|
||||
from pyrogram import __copyright__, __license__, __version__, utils
|
||||
from pyrogram import raw
|
||||
from pyrogram.connection import Connection
|
||||
from pyrogram.crypto import mtproto
|
||||
@ -51,7 +50,6 @@ class Session:
|
||||
MAX_RETRIES = 5
|
||||
ACKS_THRESHOLD = 8
|
||||
PING_INTERVAL = 5
|
||||
EXECUTOR_SIZE_THRESHOLD = 512
|
||||
|
||||
notice_displayed = False
|
||||
|
||||
@ -69,8 +67,6 @@ class Session:
|
||||
64: "[64] invalid container"
|
||||
}
|
||||
|
||||
executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: "pyrogram.Client",
|
||||
@ -220,22 +216,12 @@ class Session:
|
||||
await self.start()
|
||||
|
||||
async def handle_packet(self, packet):
|
||||
if len(packet) <= self.EXECUTOR_SIZE_THRESHOLD:
|
||||
data = mtproto.unpack(
|
||||
BytesIO(packet),
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id
|
||||
)
|
||||
else:
|
||||
data = await self.loop.run_in_executor(
|
||||
self.executor,
|
||||
mtproto.unpack,
|
||||
BytesIO(packet),
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id
|
||||
)
|
||||
data = await utils.maybe_run_in_executor(
|
||||
mtproto.unpack, BytesIO(packet), len(packet), self.loop,
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id
|
||||
)
|
||||
|
||||
messages = (
|
||||
data.body.messages
|
||||
@ -375,24 +361,13 @@ class Session:
|
||||
log.debug(f"Sent:")
|
||||
log.debug(message)
|
||||
|
||||
if len(message) <= self.EXECUTOR_SIZE_THRESHOLD:
|
||||
payload = mtproto.pack(
|
||||
message,
|
||||
self.current_salt.salt,
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id
|
||||
)
|
||||
else:
|
||||
payload = await self.loop.run_in_executor(
|
||||
self.executor,
|
||||
mtproto.pack,
|
||||
message,
|
||||
self.current_salt.salt,
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id
|
||||
)
|
||||
payload = await utils.maybe_run_in_executor(
|
||||
mtproto.pack, message, len(message), self.loop,
|
||||
self.current_salt.salt,
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id
|
||||
)
|
||||
|
||||
try:
|
||||
await self.connection.send(payload)
|
||||
|
@ -315,3 +315,11 @@ async def parse_text_entities(
|
||||
"message": text,
|
||||
"entities": entities
|
||||
}
|
||||
|
||||
|
||||
async def maybe_run_in_executor(func, data, length, loop, *args):
|
||||
return (
|
||||
func(data, *args)
|
||||
if length <= pyrogram.CRYPTO_EXECUTOR_SIZE_THRESHOLD
|
||||
else await loop.run_in_executor(pyrogram.crypto_executor, func, data, *args)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user