Merge branch 'develop' into asyncio

# Conflicts:
#	pyrogram/client/client.py
#	pyrogram/connection/connection.py
#	pyrogram/connection/transport/tcp/tcp.py
#	pyrogram/connection/transport/tcp/tcp_intermediate.py
#	pyrogram/session/session.py
This commit is contained in:
Dan 2018-09-01 01:05:25 +02:00
commit 10f3829c93
16 changed files with 116 additions and 91 deletions

View File

@ -62,6 +62,9 @@ Input Media
InputMediaPhoto
InputMediaVideo
InputMediaAudio
InputMediaAnimation
InputMediaDocument
InputPhoneContact
.. User & Chats
@ -172,5 +175,14 @@ Input Media
.. autoclass:: InputMediaVideo
:members:
.. autoclass:: InputMediaAudio
:members:
.. autoclass:: InputMediaAnimation
:members:
.. autoclass:: InputMediaDocument
:members:
.. autoclass:: InputPhoneContact
:members:

View File

@ -146,6 +146,7 @@ class Client(Methods, BaseClient):
device_model: str = None,
system_version: str = None,
lang_code: str = None,
ipv6: bool = False,
proxy: dict = None,
test_mode: bool = False,
phone_number: str = None,
@ -166,6 +167,7 @@ class Client(Methods, BaseClient):
self.device_model = device_model
self.system_version = system_version
self.lang_code = lang_code
self.ipv6 = ipv6
# TODO: Make code consistent, use underscore for private/protected fields
self._proxy = proxy
self.test_mode = test_mode
@ -201,7 +203,7 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client has already been started")
if self.BOT_TOKEN_RE.match(self.session_name):
self.token = self.session_name
self.bot_token = self.session_name
self.session_name = self.session_name.split(":")[0]
self.load_config()
@ -217,14 +219,14 @@ class Client(Methods, BaseClient):
self.is_started = True
if self.user_id is None:
if self.token is None:
if self.bot_token is None:
await self.authorize_user()
else:
await self.authorize_bot()
self.save_session()
if self.token is None:
if self.bot_token is None:
now = time.time()
if abs(now - self.date) > Client.OFFLINE_SLEEP:
@ -385,14 +387,14 @@ class Client(Methods, BaseClient):
flags=0,
api_id=self.api_id,
api_hash=self.api_hash,
bot_auth_token=self.token
bot_auth_token=self.bot_token
)
)
except UserMigrate as e:
await self.session.stop()
self.dc_id = e.x
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
self.session = Session(
self,
@ -437,7 +439,7 @@ class Client(Methods, BaseClient):
await self.session.stop()
self.dc_id = e.x
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
self.session = Session(
self,
@ -934,7 +936,7 @@ class Client(Methods, BaseClient):
except FileNotFoundError:
self.dc_id = 1
self.date = 0
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
else:
self.dc_id = s["dc_id"]
self.test_mode = s["test_mode"]
@ -1177,7 +1179,7 @@ class Client(Methods, BaseClient):
session = Session(
self,
dc_id,
await Auth(dc_id, self.test_mode, self._proxy).create(),
await Auth(dc_id, self.test_mode, self.ipv6, self._proxy).create(),
is_media=True
)
@ -1262,7 +1264,7 @@ class Client(Methods, BaseClient):
cdn_session = Session(
self,
r.dc_id,
await Auth(r.dc_id, self.test_mode, self._proxy).create(),
await Auth(r.dc_id, self.test_mode, self.ipv6, self._proxy).create(),
is_media=True,
is_cdn=True
)

View File

@ -63,7 +63,7 @@ class BaseClient:
}
def __init__(self):
self.token = None
self.bot_token = None
self.dc_id = None
self.auth_key = None
self.user_id = None

View File

@ -20,6 +20,7 @@ import asyncio
import logging
from .transport import *
from ..session.internals import DataCenter
log = logging.getLogger(__name__)
@ -36,24 +37,30 @@ class Connection:
4: TCPIntermediateO
}
def __init__(self, address: tuple, proxy: dict, mode: int = 2):
self.address = address
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, mode: int = 2):
self.ipv6 = ipv6
self.proxy = proxy
self.address = DataCenter(dc_id, test_mode, ipv6)
self.mode = self.MODES.get(mode, TCPAbridged)
self.protocol = None # type: TCP
async def connect(self):
for i in range(Connection.MAX_RETRIES):
self.protocol = self.mode(self.proxy)
self.protocol = self.mode(self.ipv6, self.proxy)
try:
log.info("Connecting...")
await self.protocol.connect(self.address)
except OSError:
except OSError as e:
log.warning(e) # TODO: Remove
self.protocol.close()
await asyncio.sleep(1)
else:
log.info("Connected! IPv{} - {}".format(
"6" if self.ipv6 else "4",
self.mode.__name__
))
break
else:
log.warning("Connection failed! Trying again...")

View File

@ -36,17 +36,17 @@ log = logging.getLogger(__name__)
class TCP:
TIMEOUT = 10
def __init__(self, proxy: dict):
def __init__(self, ipv6: bool, proxy: dict):
self.proxy = proxy
self.lock = asyncio.Lock()
self.socket = socks.socksocket()
self.socket = socks.socksocket(family=socket.AF_INET6 if ipv6 else socket.AF_INET)
self.socket.settimeout(TCP.TIMEOUT)
self.reader = None # type: asyncio.StreamReader
self.writer = None # type: asyncio.StreamWriter
self.proxy_enabled = proxy.get("enabled", False)
if proxy and self.proxy_enabled:

View File

@ -24,15 +24,13 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP):
def __init__(self, proxy: dict):
super().__init__(proxy)
def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
def connect(self, address: tuple):
super().connect(address)
super().sendall(b"\xef")
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args):
length = len(data) // 4

View File

@ -28,8 +28,9 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, proxy: dict):
super().__init__(proxy)
def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
@ -54,8 +55,6 @@ class TCPAbridgedO(TCP):
super().sendall(nonce)
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args):
length = len(data) // 4

View File

@ -26,14 +26,14 @@ log = logging.getLogger(__name__)
class TCPFull(TCP):
def __init__(self, proxy: dict):
super().__init__(proxy)
def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
self.seq_no = None
def connect(self, address: tuple):
super().connect(address)
self.seq_no = 0
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args):
# 12 = packet_length (4), seq_no (4), crc32 (4) (at the end)

View File

@ -25,19 +25,13 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP):
def __init__(self, proxy: dict):
super().__init__(proxy)
def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
async def connect(self, address: tuple):
await super().connect(address)
await super().send(b"\xee" * 4)
log.info("Connected{}!".format(
" with proxy"
if self.proxy_enabled
else ""
))
async def send(self, data: bytes, *args):
await super().send(pack("<i", len(data)) + data)

View File

@ -29,8 +29,9 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, proxy: dict):
super().__init__(proxy)
def __init__(self, ipv6: bool, proxy: dict):
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
@ -55,8 +56,6 @@ class TCPIntermediateO(TCP):
super().sendall(nonce)
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args):
super().sendall(
AES.ctr256_encrypt(

View File

@ -27,22 +27,21 @@ try:
class AES:
# TODO: Use new tgcrypto function names
@classmethod
def ige256_encrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes:
return tgcrypto.ige_encrypt(data, key, iv)
return tgcrypto.ige256_encrypt(data, key, iv)
@classmethod
def ige256_decrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes:
return tgcrypto.ige_decrypt(data, key, iv)
return tgcrypto.ige256_decrypt(data, key, iv)
@staticmethod
def ctr256_encrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes:
return tgcrypto.ctr_encrypt(data, key, iv, state or bytearray(1))
return tgcrypto.ctr256_encrypt(data, key, iv, state or bytearray(1))
@staticmethod
def ctr256_decrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes:
return tgcrypto.ctr_decrypt(data, key, iv, state or bytearray(1))
return tgcrypto.ctr256_decrypt(data, key, iv, state or bytearray(1))
@staticmethod
def xor(a: bytes, b: bytes) -> bytes:

View File

@ -20,6 +20,18 @@ from random import randint
class Prime:
CURRENT_DH_PRIME = int(
"C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
16
)
# Recursive variant
# @classmethod
# def gcd(cls, a: int, b: int) -> int:

View File

@ -26,7 +26,7 @@ from pyrogram.api import functions, types
from pyrogram.api.core import Object, Long, Int
from pyrogram.connection import Connection
from pyrogram.crypto import AES, RSA, Prime
from .internals import MsgId, DataCenter
from .internals import MsgId
log = logging.getLogger(__name__)
@ -34,21 +34,10 @@ log = logging.getLogger(__name__)
class Auth:
MAX_RETRIES = 5
CURRENT_DH_PRIME = int(
"C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
16
)
def __init__(self, dc_id: int, test_mode: bool, proxy: dict):
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict):
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = ipv6
self.proxy = proxy
self.connection = None
@ -84,7 +73,7 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
while True:
self.connection = Connection(DataCenter(self.dc_id, self.test_mode), self.proxy)
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
try:
log.info("Start creating a new auth key on DC{}".format(self.dc_id))
@ -219,7 +208,7 @@ class Auth:
# Security checks
#######################
assert dh_prime == self.CURRENT_DH_PRIME
assert dh_prime == Prime.CURRENT_DH_PRIME
log.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation

View File

@ -34,5 +34,32 @@ class DataCenter:
121: "95.213.217.195"
}
def __new__(cls, dc_id: int, test_mode: bool):
return (cls.TEST[dc_id], 80) if test_mode else (cls.PROD[dc_id], 443)
TEST_IPV6 = {
1: "2001:b28:f23d:f001::e",
2: "2001:67c:4e8:f002::e",
3: "2001:b28:f23d:f003::e",
121: "2a03:b0c0:3:d0::114:d001"
}
PROD_IPV6 = {
1: "2001:b28:f23d:f001::a",
2: "2001:67c:4e8:f002::a",
3: "2001:b28:f23d:f003::a",
4: "2001:67c:4e8:f004::a",
5: "2001:b28:f23f:f005::a",
121: "2a03:b0c0:3:d0::114:d001"
}
def __new__(cls, dc_id: int, test_mode: bool, ipv6: bool):
if ipv6:
return (
(cls.TEST_IPV6[dc_id], 80)
if test_mode
else (cls.PROD_IPV6[dc_id], 443)
)
else:
return (
(cls.TEST[dc_id], 80)
if test_mode
else (cls.PROD[dc_id], 443)
)

View File

@ -30,7 +30,7 @@ from pyrogram.api.core import Object, MsgContainer, Int, Long, FutureSalt, Futur
from pyrogram.api.errors import Error, InternalServerError, AuthKeyDuplicated
from pyrogram.connection import Connection
from pyrogram.crypto import MTProto
from .internals import MsgId, MsgFactory, DataCenter
from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
@ -109,7 +109,7 @@ class Session:
async def start(self):
while True:
self.connection = Connection(DataCenter(self.dc_id, self.client.test_mode), self.client.proxy)
self.connection = Connection(self.dc_id, self.client.test_mode, self.client.ipv6, self.client.proxy)
try:
await self.connection.connect()
@ -142,7 +142,10 @@ class Session:
self.ping_task = asyncio.ensure_future(self.ping())
log.info("Connection inited: Layer {}".format(layer))
log.info("Session initialized: Layer {}".format(layer))
log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper()))
except AuthKeyDuplicated as e:
self.stop()
raise e

View File

@ -46,25 +46,9 @@ def get_readme():
class Clean(Command):
DIST = [
"./build",
"./dist",
"./Pyrogram.egg-info"
]
API = [
"pyrogram/api/errors/exceptions",
"pyrogram/api/functions",
"pyrogram/api/types",
"pyrogram/api/all.py",
]
DOCS = [
"docs/source/functions",
"docs/source/types",
"docs/build"
]
DIST = ["./build", "./dist", "./Pyrogram.egg-info"]
API = ["pyrogram/api/errors/exceptions", "pyrogram/api/functions", "pyrogram/api/types", "pyrogram/api/all.py"]
DOCS = ["docs/source/functions", "docs/source/types", "docs/build"]
ALL = DIST + API + DOCS
description = "Clean generated files"
@ -102,7 +86,7 @@ class Clean(Command):
if self.docs:
paths.update(Clean.DOCS)
if self.all:
if self.all or not paths:
paths.update(Clean.ALL)
for path in sorted(list(paths)):
@ -114,12 +98,12 @@ class Clean(Command):
print("removing {}".format(path))
class Build(Command):
description = "Build Pyrogram files"
class Generate(Command):
description = "Generate Pyrogram files"
user_options = [
("api", None, "Build API files"),
("docs", None, "Build docs files"),
("api", None, "Generate API files"),
("docs", None, "Generate docs files")
]
def __init__(self, dist, **kw):
@ -191,6 +175,6 @@ setup(
extras_require={"tgcrypto": ["tgcrypto>=1.0.4"]},
cmdclass={
"clean": Clean,
"build": Build,
"generate": Generate
}
)