Implement non-blocking TCP connection (#71)

* impl non-blocking open connection via direct

- Proxy typing
- proxy scheme validation
- TCP typing

* impl provide connection and protocol factory

* impl non-blocking open connection via proxy
This commit is contained in:
Artem Ukolov 2024-06-15 19:08:34 +03:00 committed by GitHub
parent 589caf4466
commit c2998466b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 179 additions and 102 deletions

View File

@ -32,7 +32,7 @@ from importlib import import_module
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator
from typing import Union, List, Optional, Callable, AsyncGenerator, Type
import pyrogram
from pyrogram import __version__, __license__
@ -53,6 +53,8 @@ from pyrogram.session import Auth, Session
from pyrogram.storage import Storage, FileStorage, MemoryStorage
from pyrogram.types import User, TermsOfService
from pyrogram.utils import ainput
from .connection import Connection
from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
@ -264,7 +266,9 @@ class Client(Methods):
max_message_cache_size: int = MAX_MESSAGE_CACHE_SIZE,
storage_engine: Optional[Storage] = None,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
init_connection_params: Optional["raw.base.JSONValue"] = None
init_connection_params: Optional["raw.base.JSONValue"] = None,
connection_factory: Type[Connection] = Connection,
protocol_factory: Type[TCP] = TCPAbridged
):
super().__init__()
@ -299,6 +303,8 @@ class Client(Methods):
self.max_message_cache_size = max_message_cache_size
self.client_platform = client_platform
self.init_connection_params = init_connection_params
self.connection_factory = connection_factory
self.protocol_factory = protocol_factory
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")

View File

@ -18,7 +18,7 @@
import asyncio
import logging
from typing import Optional
from typing import Optional, Type
from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter
@ -29,19 +29,28 @@ log = logging.getLogger(__name__)
class Connection:
MAX_CONNECTION_ATTEMPTS = 3
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False):
def __init__(
self,
dc_id: int,
test_mode: bool,
ipv6: bool,
proxy: dict,
media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = ipv6
self.proxy = proxy
self.media = media
self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, media)
self.protocol: TCP = None
self.protocol: Optional[TCP] = None
async def connect(self):
async def connect(self) -> None:
for i in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = TCPAbridged(self.ipv6, self.proxy)
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
try:
log.info("Connecting...")
@ -61,11 +70,11 @@ class Connection:
log.warning("Connection failed! Trying again...")
raise ConnectionError
async def close(self):
async def close(self) -> None:
await self.protocol.close()
log.info("Disconnected")
async def send(self, data: bytes):
async def send(self, data: bytes) -> None:
await self.protocol.send(data)
async def recv(self) -> Optional[bytes]:

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .tcp import TCP
from .tcp import TCP, Proxy
from .tcp_abridged import TCPAbridged
from .tcp_abridged_o import TCPAbridgedO
from .tcp_full import TCPFull

View File

@ -21,89 +21,134 @@ import ipaddress
import logging
import socket
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Dict, TypedDict, Optional
import socks
log = logging.getLogger(__name__)
proxy_type_by_scheme: Dict[str, int] = {
"SOCKS4": socks.SOCKS4,
"SOCKS5": socks.SOCKS5,
"HTTP": socks.HTTP,
}
class Proxy(TypedDict):
scheme: str
hostname: str
port: int
username: Optional[str]
password: Optional[str]
class TCP:
TIMEOUT = 10
def __init__(self, ipv6: bool, proxy: dict):
self.socket = None
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
self.ipv6 = ipv6
self.proxy = proxy
self.reader = None
self.writer = None
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
self.proxy = proxy
async def _connect_via_proxy(
self,
destination: Tuple[str, int]
) -> None:
scheme = self.proxy.get("scheme")
if scheme is None:
raise ValueError("No scheme specified")
if proxy:
hostname = proxy.get("hostname")
proxy_type = proxy_type_by_scheme.get(scheme.upper())
if proxy_type is None:
raise ValueError(f"Unknown proxy type {scheme}")
hostname = self.proxy.get("hostname")
port = self.proxy.get("port")
username = self.proxy.get("username")
password = self.proxy.get("password")
try:
ip_address = ipaddress.ip_address(hostname)
except ValueError:
self.socket = socks.socksocket(socket.AF_INET)
is_proxy_ipv6 = False
else:
if isinstance(ip_address, ipaddress.IPv6Address):
self.socket = socks.socksocket(socket.AF_INET6)
else:
self.socket = socks.socksocket(socket.AF_INET)
is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address)
self.socket.set_proxy(
proxy_type=getattr(socks, proxy.get("scheme").upper()),
proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET
sock = socks.socksocket(proxy_family)
sock.set_proxy(
proxy_type=proxy_type,
addr=hostname,
port=proxy.get("port", None),
username=proxy.get("username", None),
password=proxy.get("password", None)
port=port,
username=username,
password=password
)
sock.settimeout(TCP.TIMEOUT)
await self.loop.sock_connect(
sock=sock,
address=destination
)
self.socket.settimeout(TCP.TIMEOUT)
sock.setblocking(False)
log.info("Using proxy %s", hostname)
else:
self.socket = socket.socket(
socket.AF_INET6 if ipv6
else socket.AF_INET
self.reader, self.writer = await asyncio.open_connection(
sock=sock
)
self.socket.setblocking(False)
async def _connect_via_direct(
self,
destination: Tuple[str, int]
) -> None:
host, port = destination
family = socket.AF_INET6 if self.ipv6 else socket.AF_INET
self.reader, self.writer = await asyncio.open_connection(
host=host,
port=port,
family=family
)
async def connect(self, address: tuple):
async def _connect(self, destination: Tuple[str, int]) -> None:
if self.proxy:
with ThreadPoolExecutor(1) as executor:
await self.loop.run_in_executor(executor, self.socket.connect, address)
await self._connect_via_proxy(destination)
else:
await self._connect_via_direct(destination)
async def connect(self, address: Tuple[str, int]) -> None:
try:
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
raise TimeoutError("Connection timed out")
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
async def close(self) -> None:
if self.writer is None:
return None
async def close(self):
try:
if self.writer is not None:
self.writer.close()
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
except Exception as e:
log.info("Close exception: %s %s", type(e).__name__, e)
async def send(self, data: bytes):
async def send(self, data: bytes) -> None:
if self.writer is None:
return None
async with self.lock:
try:
if self.writer is not None:
self.writer.write(data)
await self.writer.drain()
except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e)
raise OSError(e)
async def recv(self, length: int = 0):
async def recv(self, length: int = 0) -> Optional[bytes]:
data = b""
while len(data) < length:

View File

@ -17,22 +17,22 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
from typing import Optional
from typing import Optional, Tuple
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
await super().send(b"\xef")
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
length = len(data) // 4
await super().send(

View File

@ -18,11 +18,11 @@
import logging
import os
from typing import Optional
from typing import Optional, Tuple
import pyrogram
from pyrogram.crypto import aes
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
@ -30,13 +30,13 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
while True:
@ -55,7 +55,7 @@ class TCPAbridgedO(TCP):
await super().send(nonce)
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
length = len(data) // 4
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
payload = await self.loop.run_in_executor(pyrogram.crypto_executor, aes.ctr256_encrypt, data, *self.encrypt)

View File

@ -19,24 +19,24 @@
import logging
from binascii import crc32
from struct import pack, unpack
from typing import Optional
from typing import Optional, Tuple
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
self.seq_no = None
self.seq_no: Optional[int] = None
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
self.seq_no = 0
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
data = pack("<II", len(data) + 12, self.seq_no) + data
data += pack("<I", crc32(data))
self.seq_no += 1

View File

@ -18,22 +18,22 @@
import logging
from struct import pack, unpack
from typing import Optional
from typing import Optional, Tuple
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
await super().send(b"\xee" * 4)
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
await super().send(pack("<i", len(data)) + data)
async def recv(self, length: int = 0) -> Optional[bytes]:

View File

@ -19,10 +19,10 @@
import logging
import os
from struct import pack, unpack
from typing import Optional
from typing import Optional, Tuple
from pyrogram.crypto import aes
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
@ -30,13 +30,13 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
while True:
@ -55,7 +55,7 @@ class TCPIntermediateO(TCP):
await super().send(nonce)
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
await super().send(
aes.ctr256_encrypt(
pack("<i", len(data)) + data,

View File

@ -22,6 +22,7 @@ import time
from hashlib import sha1
from io import BytesIO
from os import urandom
from typing import Optional
import pyrogram
from pyrogram import raw
@ -37,13 +38,20 @@ log = logging.getLogger(__name__)
class Auth:
MAX_RETRIES = 5
def __init__(self, client: "pyrogram.Client", dc_id: int, test_mode: bool):
def __init__(
self,
client: "pyrogram.Client",
dc_id: int,
test_mode: bool
):
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = client.ipv6
self.proxy = client.proxy
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
self.connection = None
self.connection: Optional[Connection] = None
@staticmethod
def pack(data: TLObject) -> bytes:
@ -76,7 +84,14 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
while True:
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
self.connection = self.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
ipv6=self.ipv6,
proxy=self.proxy,
media=False,
protocol_factory=self.protocol_factory
)
try:
log.info("Start creating a new auth key on DC%s", self.dc_id)

View File

@ -22,6 +22,7 @@ import logging
import os
from hashlib import sha1
from io import BytesIO
from typing import Optional
import pyrogram
from pyrogram import raw
@ -75,7 +76,7 @@ class Session:
self.is_media = is_media
self.is_cdn = is_cdn
self.connection = None
self.connection: Optional[Connection] = None
self.auth_key_id = sha1(auth_key).digest()[-8:]
@ -101,12 +102,13 @@ class Session:
async def start(self):
while True:
self.connection = Connection(
self.dc_id,
self.test_mode,
self.client.ipv6,
self.client.proxy,
self.is_media
self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
ipv6=self.client.ipv6,
proxy=self.client.proxy,
media=self.is_media,
protocol_factory=self.client.protocol_factory
)
try: