Implement non-blocking TCP connection (#71)

* impl non-blocking open connection via direct

- Proxy typing
- proxy scheme validation
- TCP typing

* impl provide connection and protocol factory

* impl non-blocking open connection via proxy
This commit is contained in:
Artem Ukolov 2024-06-15 19:08:34 +03:00 committed by GitHub
parent 589caf4466
commit c2998466b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 179 additions and 102 deletions

View File

@ -32,7 +32,7 @@ from importlib import import_module
from io import StringIO, BytesIO from io import StringIO, BytesIO
from mimetypes import MimeTypes from mimetypes import MimeTypes
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator from typing import Union, List, Optional, Callable, AsyncGenerator, Type
import pyrogram import pyrogram
from pyrogram import __version__, __license__ from pyrogram import __version__, __license__
@ -53,6 +53,8 @@ from pyrogram.session import Auth, Session
from pyrogram.storage import Storage, FileStorage, MemoryStorage from pyrogram.storage import Storage, FileStorage, MemoryStorage
from pyrogram.types import User, TermsOfService from pyrogram.types import User, TermsOfService
from pyrogram.utils import ainput from pyrogram.utils import ainput
from .connection import Connection
from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types from .mime_types import mime_types
@ -264,7 +266,9 @@ class Client(Methods):
max_message_cache_size: int = MAX_MESSAGE_CACHE_SIZE, max_message_cache_size: int = MAX_MESSAGE_CACHE_SIZE,
storage_engine: Optional[Storage] = None, storage_engine: Optional[Storage] = None,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER, client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
init_connection_params: Optional["raw.base.JSONValue"] = None init_connection_params: Optional["raw.base.JSONValue"] = None,
connection_factory: Type[Connection] = Connection,
protocol_factory: Type[TCP] = TCPAbridged
): ):
super().__init__() super().__init__()
@ -299,6 +303,8 @@ class Client(Methods):
self.max_message_cache_size = max_message_cache_size self.max_message_cache_size = max_message_cache_size
self.client_platform = client_platform self.client_platform = client_platform
self.init_connection_params = init_connection_params self.init_connection_params = init_connection_params
self.connection_factory = connection_factory
self.protocol_factory = protocol_factory
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler") self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")

View File

@ -18,7 +18,7 @@
import asyncio import asyncio
import logging import logging
from typing import Optional from typing import Optional, Type
from .transport import TCP, TCPAbridged from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter from ..session.internals import DataCenter
@ -29,19 +29,28 @@ log = logging.getLogger(__name__)
class Connection: class Connection:
MAX_CONNECTION_ATTEMPTS = 3 MAX_CONNECTION_ATTEMPTS = 3
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False): def __init__(
self,
dc_id: int,
test_mode: bool,
ipv6: bool,
proxy: dict,
media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged
) -> None:
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.ipv6 = ipv6 self.ipv6 = ipv6
self.proxy = proxy self.proxy = proxy
self.media = media self.media = media
self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, media) self.address = DataCenter(dc_id, test_mode, ipv6, media)
self.protocol: TCP = None self.protocol: Optional[TCP] = None
async def connect(self): async def connect(self) -> None:
for i in range(Connection.MAX_CONNECTION_ATTEMPTS): for i in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = TCPAbridged(self.ipv6, self.proxy) self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
try: try:
log.info("Connecting...") log.info("Connecting...")
@ -61,11 +70,11 @@ class Connection:
log.warning("Connection failed! Trying again...") log.warning("Connection failed! Trying again...")
raise ConnectionError raise ConnectionError
async def close(self): async def close(self) -> None:
await self.protocol.close() await self.protocol.close()
log.info("Disconnected") log.info("Disconnected")
async def send(self, data: bytes): async def send(self, data: bytes) -> None:
await self.protocol.send(data) await self.protocol.send(data)
async def recv(self) -> Optional[bytes]: async def recv(self) -> Optional[bytes]:

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .tcp import TCP from .tcp import TCP, Proxy
from .tcp_abridged import TCPAbridged from .tcp_abridged import TCPAbridged
from .tcp_abridged_o import TCPAbridgedO from .tcp_abridged_o import TCPAbridgedO
from .tcp_full import TCPFull from .tcp_full import TCPFull

View File

@ -21,89 +21,134 @@ import ipaddress
import logging import logging
import socket import socket
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Dict, TypedDict, Optional
import socks import socks
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
proxy_type_by_scheme: Dict[str, int] = {
"SOCKS4": socks.SOCKS4,
"SOCKS5": socks.SOCKS5,
"HTTP": socks.HTTP,
}
class Proxy(TypedDict):
scheme: str
hostname: str
port: int
username: Optional[str]
password: Optional[str]
class TCP: class TCP:
TIMEOUT = 10 TIMEOUT = 10
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: Proxy) -> None:
self.socket = None self.ipv6 = ipv6
self.proxy = proxy
self.reader = None self.reader: Optional[asyncio.StreamReader] = None
self.writer = None self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.proxy = proxy async def _connect_via_proxy(
self,
destination: Tuple[str, int]
) -> None:
scheme = self.proxy.get("scheme")
if scheme is None:
raise ValueError("No scheme specified")
if proxy: proxy_type = proxy_type_by_scheme.get(scheme.upper())
hostname = proxy.get("hostname") if proxy_type is None:
raise ValueError(f"Unknown proxy type {scheme}")
try: hostname = self.proxy.get("hostname")
ip_address = ipaddress.ip_address(hostname) port = self.proxy.get("port")
except ValueError: username = self.proxy.get("username")
self.socket = socks.socksocket(socket.AF_INET) password = self.proxy.get("password")
else:
if isinstance(ip_address, ipaddress.IPv6Address):
self.socket = socks.socksocket(socket.AF_INET6)
else:
self.socket = socks.socksocket(socket.AF_INET)
self.socket.set_proxy(
proxy_type=getattr(socks, proxy.get("scheme").upper()),
addr=hostname,
port=proxy.get("port", None),
username=proxy.get("username", None),
password=proxy.get("password", None)
)
self.socket.settimeout(TCP.TIMEOUT)
log.info("Using proxy %s", hostname)
else:
self.socket = socket.socket(
socket.AF_INET6 if ipv6
else socket.AF_INET
)
self.socket.setblocking(False)
async def connect(self, address: tuple):
if self.proxy:
with ThreadPoolExecutor(1) as executor:
await self.loop.run_in_executor(executor, self.socket.connect, address)
else:
try:
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
raise TimeoutError("Connection timed out")
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
async def close(self):
try: try:
if self.writer is not None: ip_address = ipaddress.ip_address(hostname)
self.writer.close() except ValueError:
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) is_proxy_ipv6 = False
else:
is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address)
proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET
sock = socks.socksocket(proxy_family)
sock.set_proxy(
proxy_type=proxy_type,
addr=hostname,
port=port,
username=username,
password=password
)
sock.settimeout(TCP.TIMEOUT)
await self.loop.sock_connect(
sock=sock,
address=destination
)
sock.setblocking(False)
self.reader, self.writer = await asyncio.open_connection(
sock=sock
)
async def _connect_via_direct(
self,
destination: Tuple[str, int]
) -> None:
host, port = destination
family = socket.AF_INET6 if self.ipv6 else socket.AF_INET
self.reader, self.writer = await asyncio.open_connection(
host=host,
port=port,
family=family
)
async def _connect(self, destination: Tuple[str, int]) -> None:
if self.proxy:
await self._connect_via_proxy(destination)
else:
await self._connect_via_direct(destination)
async def connect(self, address: Tuple[str, int]) -> None:
try:
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
raise TimeoutError("Connection timed out")
async def close(self) -> None:
if self.writer is None:
return None
try:
self.writer.close()
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
except Exception as e: except Exception as e:
log.info("Close exception: %s %s", type(e).__name__, e) log.info("Close exception: %s %s", type(e).__name__, e)
async def send(self, data: bytes): async def send(self, data: bytes) -> None:
if self.writer is None:
return None
async with self.lock: async with self.lock:
try: try:
if self.writer is not None: self.writer.write(data)
self.writer.write(data) await self.writer.drain()
await self.writer.drain()
except Exception as e: except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e) log.info("Send exception: %s %s", type(e).__name__, e)
raise OSError(e) raise OSError(e)
async def recv(self, length: int = 0): async def recv(self, length: int = 0) -> Optional[bytes]:
data = b"" data = b""
while len(data) < length: while len(data) < length:

View File

@ -17,22 +17,22 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
from typing import Optional from typing import Optional, Tuple
from .tcp import TCP from .tcp import TCP, Proxy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy)
async def connect(self, address: tuple): async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)
await super().send(b"\xef") await super().send(b"\xef")
async def send(self, data: bytes, *args): async def send(self, data: bytes, *args) -> None:
length = len(data) // 4 length = len(data) // 4
await super().send( await super().send(

View File

@ -18,11 +18,11 @@
import logging import logging
import os import os
from typing import Optional from typing import Optional, Tuple
import pyrogram import pyrogram
from pyrogram.crypto import aes from pyrogram.crypto import aes
from .tcp import TCP from .tcp import TCP, Proxy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -30,13 +30,13 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP): class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None
async def connect(self, address: tuple): async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)
while True: while True:
@ -55,7 +55,7 @@ class TCPAbridgedO(TCP):
await super().send(nonce) await super().send(nonce)
async def send(self, data: bytes, *args): async def send(self, data: bytes, *args) -> None:
length = len(data) // 4 length = len(data) // 4
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
payload = await self.loop.run_in_executor(pyrogram.crypto_executor, aes.ctr256_encrypt, data, *self.encrypt) payload = await self.loop.run_in_executor(pyrogram.crypto_executor, aes.ctr256_encrypt, data, *self.encrypt)

View File

@ -19,24 +19,24 @@
import logging import logging
from binascii import crc32 from binascii import crc32
from struct import pack, unpack from struct import pack, unpack
from typing import Optional from typing import Optional, Tuple
from .tcp import TCP from .tcp import TCP, Proxy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy)
self.seq_no = None self.seq_no: Optional[int] = None
async def connect(self, address: tuple): async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)
self.seq_no = 0 self.seq_no = 0
async def send(self, data: bytes, *args): async def send(self, data: bytes, *args) -> None:
data = pack("<II", len(data) + 12, self.seq_no) + data data = pack("<II", len(data) + 12, self.seq_no) + data
data += pack("<I", crc32(data)) data += pack("<I", crc32(data))
self.seq_no += 1 self.seq_no += 1

View File

@ -18,22 +18,22 @@
import logging import logging
from struct import pack, unpack from struct import pack, unpack
from typing import Optional from typing import Optional, Tuple
from .tcp import TCP from .tcp import TCP, Proxy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy)
async def connect(self, address: tuple): async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)
await super().send(b"\xee" * 4) await super().send(b"\xee" * 4)
async def send(self, data: bytes, *args): async def send(self, data: bytes, *args) -> None:
await super().send(pack("<i", len(data)) + data) await super().send(pack("<i", len(data)) + data)
async def recv(self, length: int = 0) -> Optional[bytes]: async def recv(self, length: int = 0) -> Optional[bytes]:

View File

@ -19,10 +19,10 @@
import logging import logging
import os import os
from struct import pack, unpack from struct import pack, unpack
from typing import Optional from typing import Optional, Tuple
from pyrogram.crypto import aes from pyrogram.crypto import aes
from .tcp import TCP from .tcp import TCP, Proxy
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -30,13 +30,13 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP): class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None
async def connect(self, address: tuple): async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)
while True: while True:
@ -55,7 +55,7 @@ class TCPIntermediateO(TCP):
await super().send(nonce) await super().send(nonce)
async def send(self, data: bytes, *args): async def send(self, data: bytes, *args) -> None:
await super().send( await super().send(
aes.ctr256_encrypt( aes.ctr256_encrypt(
pack("<i", len(data)) + data, pack("<i", len(data)) + data,

View File

@ -22,6 +22,7 @@ import time
from hashlib import sha1 from hashlib import sha1
from io import BytesIO from io import BytesIO
from os import urandom from os import urandom
from typing import Optional
import pyrogram import pyrogram
from pyrogram import raw from pyrogram import raw
@ -37,13 +38,20 @@ log = logging.getLogger(__name__)
class Auth: class Auth:
MAX_RETRIES = 5 MAX_RETRIES = 5
def __init__(self, client: "pyrogram.Client", dc_id: int, test_mode: bool): def __init__(
self,
client: "pyrogram.Client",
dc_id: int,
test_mode: bool
):
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.ipv6 = client.ipv6 self.ipv6 = client.ipv6
self.proxy = client.proxy self.proxy = client.proxy
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
self.connection = None self.connection: Optional[Connection] = None
@staticmethod @staticmethod
def pack(data: TLObject) -> bytes: def pack(data: TLObject) -> bytes:
@ -76,7 +84,14 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail. # The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times. # If that happens, just try again up to MAX_RETRIES times.
while True: while True:
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy) self.connection = self.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
ipv6=self.ipv6,
proxy=self.proxy,
media=False,
protocol_factory=self.protocol_factory
)
try: try:
log.info("Start creating a new auth key on DC%s", self.dc_id) log.info("Start creating a new auth key on DC%s", self.dc_id)

View File

@ -22,6 +22,7 @@ import logging
import os import os
from hashlib import sha1 from hashlib import sha1
from io import BytesIO from io import BytesIO
from typing import Optional
import pyrogram import pyrogram
from pyrogram import raw from pyrogram import raw
@ -75,7 +76,7 @@ class Session:
self.is_media = is_media self.is_media = is_media
self.is_cdn = is_cdn self.is_cdn = is_cdn
self.connection = None self.connection: Optional[Connection] = None
self.auth_key_id = sha1(auth_key).digest()[-8:] self.auth_key_id = sha1(auth_key).digest()[-8:]
@ -101,12 +102,13 @@ class Session:
async def start(self): async def start(self):
while True: while True:
self.connection = Connection( self.connection = self.client.connection_factory(
self.dc_id, dc_id=self.dc_id,
self.test_mode, test_mode=self.test_mode,
self.client.ipv6, ipv6=self.client.ipv6,
self.client.proxy, proxy=self.client.proxy,
self.is_media media=self.is_media,
protocol_factory=self.client.protocol_factory
) )
try: try: