Move crypto calls to threads in case of big enough chunks
This commit is contained in:
parent
521e403f92
commit
844e53a70e
@ -20,6 +20,8 @@ __version__ = "1.0.7"
|
|||||||
__license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)"
|
__license__ = "GNU Lesser General Public License v3 or later (LGPLv3+)"
|
||||||
__copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>"
|
__copyright__ = "Copyright (C) 2017-2020 Dan <https://github.com/delivrance>"
|
||||||
|
|
||||||
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
|
|
||||||
|
|
||||||
class StopTransmission(StopAsyncIteration):
|
class StopTransmission(StopAsyncIteration):
|
||||||
pass
|
pass
|
||||||
@ -41,3 +43,7 @@ from .sync import idle
|
|||||||
|
|
||||||
# Save the main thread loop for future references
|
# Save the main thread loop for future references
|
||||||
main_event_loop = asyncio.get_event_loop()
|
main_event_loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
CRYPTO_EXECUTOR_SIZE_THRESHOLD = 512
|
||||||
|
|
||||||
|
crypto_executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")
|
||||||
|
@ -45,6 +45,7 @@ class TCP:
|
|||||||
self.writer = None # type: asyncio.StreamWriter
|
self.writer = None # type: asyncio.StreamWriter
|
||||||
|
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
if proxy.get("enabled", False):
|
if proxy.get("enabled", False):
|
||||||
hostname = proxy.get("hostname", None)
|
hostname = proxy.get("hostname", None)
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from pyrogram import utils
|
||||||
from pyrogram.crypto import aes
|
from pyrogram.crypto import aes
|
||||||
from .tcp import TCP
|
from .tcp import TCP
|
||||||
|
|
||||||
@ -55,16 +56,10 @@ class TCPAbridgedO(TCP):
|
|||||||
|
|
||||||
async def send(self, data: bytes, *args):
|
async def send(self, data: bytes, *args):
|
||||||
length = len(data) // 4
|
length = len(data) // 4
|
||||||
|
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
|
||||||
|
payload = await utils.maybe_run_in_executor(aes.ctr256_encrypt, data, len(data), self.loop, *self.encrypt)
|
||||||
|
|
||||||
await super().send(
|
await super().send(payload)
|
||||||
aes.ctr256_encrypt(
|
|
||||||
(bytes([length])
|
|
||||||
if length <= 126
|
|
||||||
else b"\x7f" + length.to_bytes(3, "little"))
|
|
||||||
+ data,
|
|
||||||
*self.encrypt
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def recv(self, length: int = 0) -> bytes or None:
|
async def recv(self, length: int = 0) -> bytes or None:
|
||||||
length = await super().recv(1)
|
length = await super().recv(1)
|
||||||
@ -87,4 +82,4 @@ class TCPAbridgedO(TCP):
|
|||||||
if data is None:
|
if data is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return aes.ctr256_decrypt(data, *self.decrypt)
|
return await utils.maybe_run_in_executor(aes.ctr256_decrypt, data, len(data), self.loop, *self.decrypt)
|
||||||
|
@ -19,13 +19,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram import __copyright__, __license__, __version__
|
from pyrogram import __copyright__, __license__, __version__, utils
|
||||||
from pyrogram import raw
|
from pyrogram import raw
|
||||||
from pyrogram.connection import Connection
|
from pyrogram.connection import Connection
|
||||||
from pyrogram.crypto import mtproto
|
from pyrogram.crypto import mtproto
|
||||||
@ -51,7 +50,6 @@ class Session:
|
|||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
ACKS_THRESHOLD = 8
|
ACKS_THRESHOLD = 8
|
||||||
PING_INTERVAL = 5
|
PING_INTERVAL = 5
|
||||||
EXECUTOR_SIZE_THRESHOLD = 512
|
|
||||||
|
|
||||||
notice_displayed = False
|
notice_displayed = False
|
||||||
|
|
||||||
@ -69,8 +67,6 @@ class Session:
|
|||||||
64: "[64] invalid container"
|
64: "[64] invalid container"
|
||||||
}
|
}
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(2, thread_name_prefix="CryptoWorker")
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: "pyrogram.Client",
|
client: "pyrogram.Client",
|
||||||
@ -220,18 +216,8 @@ class Session:
|
|||||||
await self.start()
|
await self.start()
|
||||||
|
|
||||||
async def handle_packet(self, packet):
|
async def handle_packet(self, packet):
|
||||||
if len(packet) <= self.EXECUTOR_SIZE_THRESHOLD:
|
data = await utils.maybe_run_in_executor(
|
||||||
data = mtproto.unpack(
|
mtproto.unpack, BytesIO(packet), len(packet), self.loop,
|
||||||
BytesIO(packet),
|
|
||||||
self.session_id,
|
|
||||||
self.auth_key,
|
|
||||||
self.auth_key_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data = await self.loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
mtproto.unpack,
|
|
||||||
BytesIO(packet),
|
|
||||||
self.session_id,
|
self.session_id,
|
||||||
self.auth_key,
|
self.auth_key,
|
||||||
self.auth_key_id
|
self.auth_key_id
|
||||||
@ -375,19 +361,8 @@ class Session:
|
|||||||
log.debug(f"Sent:")
|
log.debug(f"Sent:")
|
||||||
log.debug(message)
|
log.debug(message)
|
||||||
|
|
||||||
if len(message) <= self.EXECUTOR_SIZE_THRESHOLD:
|
payload = await utils.maybe_run_in_executor(
|
||||||
payload = mtproto.pack(
|
mtproto.pack, message, len(message), self.loop,
|
||||||
message,
|
|
||||||
self.current_salt.salt,
|
|
||||||
self.session_id,
|
|
||||||
self.auth_key,
|
|
||||||
self.auth_key_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
payload = await self.loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
mtproto.pack,
|
|
||||||
message,
|
|
||||||
self.current_salt.salt,
|
self.current_salt.salt,
|
||||||
self.session_id,
|
self.session_id,
|
||||||
self.auth_key,
|
self.auth_key,
|
||||||
|
@ -315,3 +315,11 @@ async def parse_text_entities(
|
|||||||
"message": text,
|
"message": text,
|
||||||
"entities": entities
|
"entities": entities
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def maybe_run_in_executor(func, data, length, loop, *args):
|
||||||
|
return (
|
||||||
|
func(data, *args)
|
||||||
|
if length <= pyrogram.CRYPTO_EXECUTOR_SIZE_THRESHOLD
|
||||||
|
else await loop.run_in_executor(pyrogram.crypto_executor, func, data, *args)
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user