mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-18 13:34:54 +00:00
Merge branch 'develop' into asyncio
# Conflicts: # pyrogram/client/client.py # pyrogram/connection/connection.py # pyrogram/connection/transport/tcp/tcp.py # pyrogram/connection/transport/tcp/tcp_intermediate.py # pyrogram/session/session.py
This commit is contained in:
commit
10f3829c93
@ -62,6 +62,9 @@ Input Media
|
||||
|
||||
InputMediaPhoto
|
||||
InputMediaVideo
|
||||
InputMediaAudio
|
||||
InputMediaAnimation
|
||||
InputMediaDocument
|
||||
InputPhoneContact
|
||||
|
||||
.. User & Chats
|
||||
@ -172,5 +175,14 @@ Input Media
|
||||
.. autoclass:: InputMediaVideo
|
||||
:members:
|
||||
|
||||
.. autoclass:: InputMediaAudio
|
||||
:members:
|
||||
|
||||
.. autoclass:: InputMediaAnimation
|
||||
:members:
|
||||
|
||||
.. autoclass:: InputMediaDocument
|
||||
:members:
|
||||
|
||||
.. autoclass:: InputPhoneContact
|
||||
:members:
|
||||
|
@ -146,6 +146,7 @@ class Client(Methods, BaseClient):
|
||||
device_model: str = None,
|
||||
system_version: str = None,
|
||||
lang_code: str = None,
|
||||
ipv6: bool = False,
|
||||
proxy: dict = None,
|
||||
test_mode: bool = False,
|
||||
phone_number: str = None,
|
||||
@ -166,6 +167,7 @@ class Client(Methods, BaseClient):
|
||||
self.device_model = device_model
|
||||
self.system_version = system_version
|
||||
self.lang_code = lang_code
|
||||
self.ipv6 = ipv6
|
||||
# TODO: Make code consistent, use underscore for private/protected fields
|
||||
self._proxy = proxy
|
||||
self.test_mode = test_mode
|
||||
@ -201,7 +203,7 @@ class Client(Methods, BaseClient):
|
||||
raise ConnectionError("Client has already been started")
|
||||
|
||||
if self.BOT_TOKEN_RE.match(self.session_name):
|
||||
self.token = self.session_name
|
||||
self.bot_token = self.session_name
|
||||
self.session_name = self.session_name.split(":")[0]
|
||||
|
||||
self.load_config()
|
||||
@ -217,14 +219,14 @@ class Client(Methods, BaseClient):
|
||||
self.is_started = True
|
||||
|
||||
if self.user_id is None:
|
||||
if self.token is None:
|
||||
if self.bot_token is None:
|
||||
await self.authorize_user()
|
||||
else:
|
||||
await self.authorize_bot()
|
||||
|
||||
self.save_session()
|
||||
|
||||
if self.token is None:
|
||||
if self.bot_token is None:
|
||||
now = time.time()
|
||||
|
||||
if abs(now - self.date) > Client.OFFLINE_SLEEP:
|
||||
@ -385,14 +387,14 @@ class Client(Methods, BaseClient):
|
||||
flags=0,
|
||||
api_id=self.api_id,
|
||||
api_hash=self.api_hash,
|
||||
bot_auth_token=self.token
|
||||
bot_auth_token=self.bot_token
|
||||
)
|
||||
)
|
||||
except UserMigrate as e:
|
||||
await self.session.stop()
|
||||
|
||||
self.dc_id = e.x
|
||||
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
|
||||
self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||
|
||||
self.session = Session(
|
||||
self,
|
||||
@ -437,7 +439,7 @@ class Client(Methods, BaseClient):
|
||||
await self.session.stop()
|
||||
|
||||
self.dc_id = e.x
|
||||
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
|
||||
self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||
|
||||
self.session = Session(
|
||||
self,
|
||||
@ -934,7 +936,7 @@ class Client(Methods, BaseClient):
|
||||
except FileNotFoundError:
|
||||
self.dc_id = 1
|
||||
self.date = 0
|
||||
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
|
||||
self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||
else:
|
||||
self.dc_id = s["dc_id"]
|
||||
self.test_mode = s["test_mode"]
|
||||
@ -1177,7 +1179,7 @@ class Client(Methods, BaseClient):
|
||||
session = Session(
|
||||
self,
|
||||
dc_id,
|
||||
await Auth(dc_id, self.test_mode, self._proxy).create(),
|
||||
await Auth(dc_id, self.test_mode, self.ipv6, self._proxy).create(),
|
||||
is_media=True
|
||||
)
|
||||
|
||||
@ -1262,7 +1264,7 @@ class Client(Methods, BaseClient):
|
||||
cdn_session = Session(
|
||||
self,
|
||||
r.dc_id,
|
||||
await Auth(r.dc_id, self.test_mode, self._proxy).create(),
|
||||
await Auth(r.dc_id, self.test_mode, self.ipv6, self._proxy).create(),
|
||||
is_media=True,
|
||||
is_cdn=True
|
||||
)
|
||||
|
@ -63,7 +63,7 @@ class BaseClient:
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.token = None
|
||||
self.bot_token = None
|
||||
self.dc_id = None
|
||||
self.auth_key = None
|
||||
self.user_id = None
|
||||
|
@ -20,6 +20,7 @@ import asyncio
|
||||
import logging
|
||||
|
||||
from .transport import *
|
||||
from ..session.internals import DataCenter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -36,24 +37,30 @@ class Connection:
|
||||
4: TCPIntermediateO
|
||||
}
|
||||
|
||||
def __init__(self, address: tuple, proxy: dict, mode: int = 2):
|
||||
self.address = address
|
||||
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, mode: int = 2):
|
||||
self.ipv6 = ipv6
|
||||
self.proxy = proxy
|
||||
self.address = DataCenter(dc_id, test_mode, ipv6)
|
||||
self.mode = self.MODES.get(mode, TCPAbridged)
|
||||
|
||||
self.protocol = None # type: TCP
|
||||
|
||||
async def connect(self):
|
||||
for i in range(Connection.MAX_RETRIES):
|
||||
self.protocol = self.mode(self.proxy)
|
||||
self.protocol = self.mode(self.ipv6, self.proxy)
|
||||
|
||||
try:
|
||||
log.info("Connecting...")
|
||||
await self.protocol.connect(self.address)
|
||||
except OSError:
|
||||
except OSError as e:
|
||||
log.warning(e) # TODO: Remove
|
||||
self.protocol.close()
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
log.info("Connected! IPv{} - {}".format(
|
||||
"6" if self.ipv6 else "4",
|
||||
self.mode.__name__
|
||||
))
|
||||
break
|
||||
else:
|
||||
log.warning("Connection failed! Trying again...")
|
||||
|
@ -36,17 +36,17 @@ log = logging.getLogger(__name__)
|
||||
class TCP:
|
||||
TIMEOUT = 10
|
||||
|
||||
def __init__(self, proxy: dict):
|
||||
def __init__(self, ipv6: bool, proxy: dict):
|
||||
self.proxy = proxy
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
self.socket = socks.socksocket()
|
||||
self.socket = socks.socksocket(family=socket.AF_INET6 if ipv6 else socket.AF_INET)
|
||||
|
||||
self.socket.settimeout(TCP.TIMEOUT)
|
||||
|
||||
self.reader = None # type: asyncio.StreamReader
|
||||
self.writer = None # type: asyncio.StreamWriter
|
||||
|
||||
self.proxy_enabled = proxy.get("enabled", False)
|
||||
|
||||
if proxy and self.proxy_enabled:
|
||||
|
@ -24,15 +24,13 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TCPAbridged(TCP):
|
||||
def __init__(self, proxy: dict):
|
||||
super().__init__(proxy)
|
||||
def __init__(self, ipv6: bool, proxy: dict):
|
||||
super().__init__(ipv6, proxy)
|
||||
|
||||
def connect(self, address: tuple):
|
||||
super().connect(address)
|
||||
super().sendall(b"\xef")
|
||||
|
||||
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
|
||||
|
||||
def sendall(self, data: bytes, *args):
|
||||
length = len(data) // 4
|
||||
|
||||
|
@ -28,8 +28,9 @@ log = logging.getLogger(__name__)
|
||||
class TCPAbridgedO(TCP):
|
||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||
|
||||
def __init__(self, proxy: dict):
|
||||
super().__init__(proxy)
|
||||
def __init__(self, ipv6: bool, proxy: dict):
|
||||
super().__init__(ipv6, proxy)
|
||||
|
||||
self.encrypt = None
|
||||
self.decrypt = None
|
||||
|
||||
@ -54,8 +55,6 @@ class TCPAbridgedO(TCP):
|
||||
|
||||
super().sendall(nonce)
|
||||
|
||||
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
|
||||
|
||||
def sendall(self, data: bytes, *args):
|
||||
length = len(data) // 4
|
||||
|
||||
|
@ -26,14 +26,14 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TCPFull(TCP):
|
||||
def __init__(self, proxy: dict):
|
||||
super().__init__(proxy)
|
||||
def __init__(self, ipv6: bool, proxy: dict):
|
||||
super().__init__(ipv6, proxy)
|
||||
|
||||
self.seq_no = None
|
||||
|
||||
def connect(self, address: tuple):
|
||||
super().connect(address)
|
||||
self.seq_no = 0
|
||||
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
|
||||
|
||||
def sendall(self, data: bytes, *args):
|
||||
# 12 = packet_length (4), seq_no (4), crc32 (4) (at the end)
|
||||
|
@ -25,19 +25,13 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TCPIntermediate(TCP):
|
||||
def __init__(self, proxy: dict):
|
||||
super().__init__(proxy)
|
||||
def __init__(self, ipv6: bool, proxy: dict):
|
||||
super().__init__(ipv6, proxy)
|
||||
|
||||
async def connect(self, address: tuple):
|
||||
await super().connect(address)
|
||||
await super().send(b"\xee" * 4)
|
||||
|
||||
log.info("Connected{}!".format(
|
||||
" with proxy"
|
||||
if self.proxy_enabled
|
||||
else ""
|
||||
))
|
||||
|
||||
async def send(self, data: bytes, *args):
|
||||
await super().send(pack("<i", len(data)) + data)
|
||||
|
||||
|
@ -29,8 +29,9 @@ log = logging.getLogger(__name__)
|
||||
class TCPIntermediateO(TCP):
|
||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||
|
||||
def __init__(self, proxy: dict):
|
||||
super().__init__(proxy)
|
||||
def __init__(self, ipv6: bool, proxy: dict):
|
||||
super().__init__(ipv6, proxy)
|
||||
|
||||
self.encrypt = None
|
||||
self.decrypt = None
|
||||
|
||||
@ -55,8 +56,6 @@ class TCPIntermediateO(TCP):
|
||||
|
||||
super().sendall(nonce)
|
||||
|
||||
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
|
||||
|
||||
def sendall(self, data: bytes, *args):
|
||||
super().sendall(
|
||||
AES.ctr256_encrypt(
|
||||
|
@ -27,22 +27,21 @@ try:
|
||||
|
||||
|
||||
class AES:
|
||||
# TODO: Use new tgcrypto function names
|
||||
@classmethod
|
||||
def ige256_encrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes:
|
||||
return tgcrypto.ige_encrypt(data, key, iv)
|
||||
return tgcrypto.ige256_encrypt(data, key, iv)
|
||||
|
||||
@classmethod
|
||||
def ige256_decrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes:
|
||||
return tgcrypto.ige_decrypt(data, key, iv)
|
||||
return tgcrypto.ige256_decrypt(data, key, iv)
|
||||
|
||||
@staticmethod
|
||||
def ctr256_encrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes:
|
||||
return tgcrypto.ctr_encrypt(data, key, iv, state or bytearray(1))
|
||||
return tgcrypto.ctr256_encrypt(data, key, iv, state or bytearray(1))
|
||||
|
||||
@staticmethod
|
||||
def ctr256_decrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes:
|
||||
return tgcrypto.ctr_decrypt(data, key, iv, state or bytearray(1))
|
||||
return tgcrypto.ctr256_decrypt(data, key, iv, state or bytearray(1))
|
||||
|
||||
@staticmethod
|
||||
def xor(a: bytes, b: bytes) -> bytes:
|
||||
|
@ -20,6 +20,18 @@ from random import randint
|
||||
|
||||
|
||||
class Prime:
|
||||
CURRENT_DH_PRIME = int(
|
||||
"C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
|
||||
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
|
||||
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
|
||||
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
|
||||
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
|
||||
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
|
||||
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
|
||||
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
|
||||
16
|
||||
)
|
||||
|
||||
# Recursive variant
|
||||
# @classmethod
|
||||
# def gcd(cls, a: int, b: int) -> int:
|
||||
|
@ -26,7 +26,7 @@ from pyrogram.api import functions, types
|
||||
from pyrogram.api.core import Object, Long, Int
|
||||
from pyrogram.connection import Connection
|
||||
from pyrogram.crypto import AES, RSA, Prime
|
||||
from .internals import MsgId, DataCenter
|
||||
from .internals import MsgId
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -34,21 +34,10 @@ log = logging.getLogger(__name__)
|
||||
class Auth:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
CURRENT_DH_PRIME = int(
|
||||
"C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
|
||||
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
|
||||
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
|
||||
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
|
||||
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
|
||||
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
|
||||
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
|
||||
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
|
||||
16
|
||||
)
|
||||
|
||||
def __init__(self, dc_id: int, test_mode: bool, proxy: dict):
|
||||
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict):
|
||||
self.dc_id = dc_id
|
||||
self.test_mode = test_mode
|
||||
self.ipv6 = ipv6
|
||||
self.proxy = proxy
|
||||
|
||||
self.connection = None
|
||||
@ -84,7 +73,7 @@ class Auth:
|
||||
# The server may close the connection at any time, causing the auth key creation to fail.
|
||||
# If that happens, just try again up to MAX_RETRIES times.
|
||||
while True:
|
||||
self.connection = Connection(DataCenter(self.dc_id, self.test_mode), self.proxy)
|
||||
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
|
||||
|
||||
try:
|
||||
log.info("Start creating a new auth key on DC{}".format(self.dc_id))
|
||||
@ -219,7 +208,7 @@ class Auth:
|
||||
# Security checks
|
||||
#######################
|
||||
|
||||
assert dh_prime == self.CURRENT_DH_PRIME
|
||||
assert dh_prime == Prime.CURRENT_DH_PRIME
|
||||
log.debug("DH parameters check: OK")
|
||||
|
||||
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
|
||||
|
@ -34,5 +34,32 @@ class DataCenter:
|
||||
121: "95.213.217.195"
|
||||
}
|
||||
|
||||
def __new__(cls, dc_id: int, test_mode: bool):
|
||||
return (cls.TEST[dc_id], 80) if test_mode else (cls.PROD[dc_id], 443)
|
||||
TEST_IPV6 = {
|
||||
1: "2001:b28:f23d:f001::e",
|
||||
2: "2001:67c:4e8:f002::e",
|
||||
3: "2001:b28:f23d:f003::e",
|
||||
121: "2a03:b0c0:3:d0::114:d001"
|
||||
}
|
||||
|
||||
PROD_IPV6 = {
|
||||
1: "2001:b28:f23d:f001::a",
|
||||
2: "2001:67c:4e8:f002::a",
|
||||
3: "2001:b28:f23d:f003::a",
|
||||
4: "2001:67c:4e8:f004::a",
|
||||
5: "2001:b28:f23f:f005::a",
|
||||
121: "2a03:b0c0:3:d0::114:d001"
|
||||
}
|
||||
|
||||
def __new__(cls, dc_id: int, test_mode: bool, ipv6: bool):
|
||||
if ipv6:
|
||||
return (
|
||||
(cls.TEST_IPV6[dc_id], 80)
|
||||
if test_mode
|
||||
else (cls.PROD_IPV6[dc_id], 443)
|
||||
)
|
||||
else:
|
||||
return (
|
||||
(cls.TEST[dc_id], 80)
|
||||
if test_mode
|
||||
else (cls.PROD[dc_id], 443)
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ from pyrogram.api.core import Object, MsgContainer, Int, Long, FutureSalt, Futur
|
||||
from pyrogram.api.errors import Error, InternalServerError, AuthKeyDuplicated
|
||||
from pyrogram.connection import Connection
|
||||
from pyrogram.crypto import MTProto
|
||||
from .internals import MsgId, MsgFactory, DataCenter
|
||||
from .internals import MsgId, MsgFactory
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -109,7 +109,7 @@ class Session:
|
||||
|
||||
async def start(self):
|
||||
while True:
|
||||
self.connection = Connection(DataCenter(self.dc_id, self.client.test_mode), self.client.proxy)
|
||||
self.connection = Connection(self.dc_id, self.client.test_mode, self.client.ipv6, self.client.proxy)
|
||||
|
||||
try:
|
||||
await self.connection.connect()
|
||||
@ -142,7 +142,10 @@ class Session:
|
||||
|
||||
self.ping_task = asyncio.ensure_future(self.ping())
|
||||
|
||||
log.info("Connection inited: Layer {}".format(layer))
|
||||
log.info("Session initialized: Layer {}".format(layer))
|
||||
log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
|
||||
log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper()))
|
||||
|
||||
except AuthKeyDuplicated as e:
|
||||
self.stop()
|
||||
raise e
|
||||
|
34
setup.py
34
setup.py
@ -46,25 +46,9 @@ def get_readme():
|
||||
|
||||
|
||||
class Clean(Command):
|
||||
DIST = [
|
||||
"./build",
|
||||
"./dist",
|
||||
"./Pyrogram.egg-info"
|
||||
]
|
||||
|
||||
API = [
|
||||
"pyrogram/api/errors/exceptions",
|
||||
"pyrogram/api/functions",
|
||||
"pyrogram/api/types",
|
||||
"pyrogram/api/all.py",
|
||||
]
|
||||
|
||||
DOCS = [
|
||||
"docs/source/functions",
|
||||
"docs/source/types",
|
||||
"docs/build"
|
||||
]
|
||||
|
||||
DIST = ["./build", "./dist", "./Pyrogram.egg-info"]
|
||||
API = ["pyrogram/api/errors/exceptions", "pyrogram/api/functions", "pyrogram/api/types", "pyrogram/api/all.py"]
|
||||
DOCS = ["docs/source/functions", "docs/source/types", "docs/build"]
|
||||
ALL = DIST + API + DOCS
|
||||
|
||||
description = "Clean generated files"
|
||||
@ -102,7 +86,7 @@ class Clean(Command):
|
||||
if self.docs:
|
||||
paths.update(Clean.DOCS)
|
||||
|
||||
if self.all:
|
||||
if self.all or not paths:
|
||||
paths.update(Clean.ALL)
|
||||
|
||||
for path in sorted(list(paths)):
|
||||
@ -114,12 +98,12 @@ class Clean(Command):
|
||||
print("removing {}".format(path))
|
||||
|
||||
|
||||
class Build(Command):
|
||||
description = "Build Pyrogram files"
|
||||
class Generate(Command):
|
||||
description = "Generate Pyrogram files"
|
||||
|
||||
user_options = [
|
||||
("api", None, "Build API files"),
|
||||
("docs", None, "Build docs files"),
|
||||
("api", None, "Generate API files"),
|
||||
("docs", None, "Generate docs files")
|
||||
]
|
||||
|
||||
def __init__(self, dist, **kw):
|
||||
@ -191,6 +175,6 @@ setup(
|
||||
extras_require={"tgcrypto": ["tgcrypto>=1.0.4"]},
|
||||
cmdclass={
|
||||
"clean": Clean,
|
||||
"build": Build,
|
||||
"generate": Generate
|
||||
}
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user