Merge branch 'develop' into asyncio

# Conflicts:
#	pyrogram/client/client.py
#	pyrogram/connection/connection.py
#	pyrogram/connection/transport/tcp/tcp.py
#	pyrogram/connection/transport/tcp/tcp_intermediate.py
#	pyrogram/session/session.py
This commit is contained in:
Dan 2018-09-01 01:05:25 +02:00
commit 10f3829c93
16 changed files with 116 additions and 91 deletions

View File

@ -62,6 +62,9 @@ Input Media
InputMediaPhoto InputMediaPhoto
InputMediaVideo InputMediaVideo
InputMediaAudio
InputMediaAnimation
InputMediaDocument
InputPhoneContact InputPhoneContact
.. User & Chats .. User & Chats
@ -172,5 +175,14 @@ Input Media
.. autoclass:: InputMediaVideo .. autoclass:: InputMediaVideo
:members: :members:
.. autoclass:: InputMediaAudio
:members:
.. autoclass:: InputMediaAnimation
:members:
.. autoclass:: InputMediaDocument
:members:
.. autoclass:: InputPhoneContact .. autoclass:: InputPhoneContact
:members: :members:

View File

@ -146,6 +146,7 @@ class Client(Methods, BaseClient):
device_model: str = None, device_model: str = None,
system_version: str = None, system_version: str = None,
lang_code: str = None, lang_code: str = None,
ipv6: bool = False,
proxy: dict = None, proxy: dict = None,
test_mode: bool = False, test_mode: bool = False,
phone_number: str = None, phone_number: str = None,
@ -166,6 +167,7 @@ class Client(Methods, BaseClient):
self.device_model = device_model self.device_model = device_model
self.system_version = system_version self.system_version = system_version
self.lang_code = lang_code self.lang_code = lang_code
self.ipv6 = ipv6
# TODO: Make code consistent, use underscore for private/protected fields # TODO: Make code consistent, use underscore for private/protected fields
self._proxy = proxy self._proxy = proxy
self.test_mode = test_mode self.test_mode = test_mode
@ -201,7 +203,7 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client has already been started") raise ConnectionError("Client has already been started")
if self.BOT_TOKEN_RE.match(self.session_name): if self.BOT_TOKEN_RE.match(self.session_name):
self.token = self.session_name self.bot_token = self.session_name
self.session_name = self.session_name.split(":")[0] self.session_name = self.session_name.split(":")[0]
self.load_config() self.load_config()
@ -217,14 +219,14 @@ class Client(Methods, BaseClient):
self.is_started = True self.is_started = True
if self.user_id is None: if self.user_id is None:
if self.token is None: if self.bot_token is None:
await self.authorize_user() await self.authorize_user()
else: else:
await self.authorize_bot() await self.authorize_bot()
self.save_session() self.save_session()
if self.token is None: if self.bot_token is None:
now = time.time() now = time.time()
if abs(now - self.date) > Client.OFFLINE_SLEEP: if abs(now - self.date) > Client.OFFLINE_SLEEP:
@ -385,14 +387,14 @@ class Client(Methods, BaseClient):
flags=0, flags=0,
api_id=self.api_id, api_id=self.api_id,
api_hash=self.api_hash, api_hash=self.api_hash,
bot_auth_token=self.token bot_auth_token=self.bot_token
) )
) )
except UserMigrate as e: except UserMigrate as e:
await self.session.stop() await self.session.stop()
self.dc_id = e.x self.dc_id = e.x
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create() self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
self.session = Session( self.session = Session(
self, self,
@ -437,7 +439,7 @@ class Client(Methods, BaseClient):
await self.session.stop() await self.session.stop()
self.dc_id = e.x self.dc_id = e.x
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create() self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
self.session = Session( self.session = Session(
self, self,
@ -934,7 +936,7 @@ class Client(Methods, BaseClient):
except FileNotFoundError: except FileNotFoundError:
self.dc_id = 1 self.dc_id = 1
self.date = 0 self.date = 0
self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create() self.auth_key = await Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
else: else:
self.dc_id = s["dc_id"] self.dc_id = s["dc_id"]
self.test_mode = s["test_mode"] self.test_mode = s["test_mode"]
@ -1177,7 +1179,7 @@ class Client(Methods, BaseClient):
session = Session( session = Session(
self, self,
dc_id, dc_id,
await Auth(dc_id, self.test_mode, self._proxy).create(), await Auth(dc_id, self.test_mode, self.ipv6, self._proxy).create(),
is_media=True is_media=True
) )
@ -1262,7 +1264,7 @@ class Client(Methods, BaseClient):
cdn_session = Session( cdn_session = Session(
self, self,
r.dc_id, r.dc_id,
await Auth(r.dc_id, self.test_mode, self._proxy).create(), await Auth(r.dc_id, self.test_mode, self.ipv6, self._proxy).create(),
is_media=True, is_media=True,
is_cdn=True is_cdn=True
) )

View File

@ -63,7 +63,7 @@ class BaseClient:
} }
def __init__(self): def __init__(self):
self.token = None self.bot_token = None
self.dc_id = None self.dc_id = None
self.auth_key = None self.auth_key = None
self.user_id = None self.user_id = None

View File

@ -20,6 +20,7 @@ import asyncio
import logging import logging
from .transport import * from .transport import *
from ..session.internals import DataCenter
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -36,24 +37,30 @@ class Connection:
4: TCPIntermediateO 4: TCPIntermediateO
} }
def __init__(self, address: tuple, proxy: dict, mode: int = 2): def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, mode: int = 2):
self.address = address self.ipv6 = ipv6
self.proxy = proxy self.proxy = proxy
self.address = DataCenter(dc_id, test_mode, ipv6)
self.mode = self.MODES.get(mode, TCPAbridged) self.mode = self.MODES.get(mode, TCPAbridged)
self.protocol = None # type: TCP self.protocol = None # type: TCP
async def connect(self): async def connect(self):
for i in range(Connection.MAX_RETRIES): for i in range(Connection.MAX_RETRIES):
self.protocol = self.mode(self.proxy) self.protocol = self.mode(self.ipv6, self.proxy)
try: try:
log.info("Connecting...") log.info("Connecting...")
await self.protocol.connect(self.address) await self.protocol.connect(self.address)
except OSError: except OSError as e:
log.warning(e) # TODO: Remove
self.protocol.close() self.protocol.close()
await asyncio.sleep(1) await asyncio.sleep(1)
else: else:
log.info("Connected! IPv{} - {}".format(
"6" if self.ipv6 else "4",
self.mode.__name__
))
break break
else: else:
log.warning("Connection failed! Trying again...") log.warning("Connection failed! Trying again...")

View File

@ -36,17 +36,17 @@ log = logging.getLogger(__name__)
class TCP: class TCP:
TIMEOUT = 10 TIMEOUT = 10
def __init__(self, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
self.proxy = proxy self.proxy = proxy
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.socket = socks.socksocket() self.socket = socks.socksocket(family=socket.AF_INET6 if ipv6 else socket.AF_INET)
self.socket.settimeout(TCP.TIMEOUT) self.socket.settimeout(TCP.TIMEOUT)
self.reader = None # type: asyncio.StreamReader self.reader = None # type: asyncio.StreamReader
self.writer = None # type: asyncio.StreamWriter self.writer = None # type: asyncio.StreamWriter
self.proxy_enabled = proxy.get("enabled", False) self.proxy_enabled = proxy.get("enabled", False)
if proxy and self.proxy_enabled: if proxy and self.proxy_enabled:

View File

@ -24,15 +24,13 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
super().__init__(proxy) super().__init__(ipv6, proxy)
def connect(self, address: tuple): def connect(self, address: tuple):
super().connect(address) super().connect(address)
super().sendall(b"\xef") super().sendall(b"\xef")
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args): def sendall(self, data: bytes, *args):
length = len(data) // 4 length = len(data) // 4

View File

@ -28,8 +28,9 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP): class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
super().__init__(proxy) super().__init__(ipv6, proxy)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None
@ -54,8 +55,6 @@ class TCPAbridgedO(TCP):
super().sendall(nonce) super().sendall(nonce)
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args): def sendall(self, data: bytes, *args):
length = len(data) // 4 length = len(data) // 4

View File

@ -26,14 +26,14 @@ log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
super().__init__(proxy) super().__init__(ipv6, proxy)
self.seq_no = None self.seq_no = None
def connect(self, address: tuple): def connect(self, address: tuple):
super().connect(address) super().connect(address)
self.seq_no = 0 self.seq_no = 0
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args): def sendall(self, data: bytes, *args):
# 12 = packet_length (4), seq_no (4), crc32 (4) (at the end) # 12 = packet_length (4), seq_no (4), crc32 (4) (at the end)

View File

@ -25,19 +25,13 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
super().__init__(proxy) super().__init__(ipv6, proxy)
async def connect(self, address: tuple): async def connect(self, address: tuple):
await super().connect(address) await super().connect(address)
await super().send(b"\xee" * 4) await super().send(b"\xee" * 4)
log.info("Connected{}!".format(
" with proxy"
if self.proxy_enabled
else ""
))
async def send(self, data: bytes, *args): async def send(self, data: bytes, *args):
await super().send(pack("<i", len(data)) + data) await super().send(pack("<i", len(data)) + data)

View File

@ -29,8 +29,9 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP): class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
super().__init__(proxy) super().__init__(ipv6, proxy)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None
@ -55,8 +56,6 @@ class TCPIntermediateO(TCP):
super().sendall(nonce) super().sendall(nonce)
log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def sendall(self, data: bytes, *args): def sendall(self, data: bytes, *args):
super().sendall( super().sendall(
AES.ctr256_encrypt( AES.ctr256_encrypt(

View File

@ -27,22 +27,21 @@ try:
class AES: class AES:
# TODO: Use new tgcrypto function names
@classmethod @classmethod
def ige256_encrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes: def ige256_encrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes:
return tgcrypto.ige_encrypt(data, key, iv) return tgcrypto.ige256_encrypt(data, key, iv)
@classmethod @classmethod
def ige256_decrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes: def ige256_decrypt(cls, data: bytes, key: bytes, iv: bytes) -> bytes:
return tgcrypto.ige_decrypt(data, key, iv) return tgcrypto.ige256_decrypt(data, key, iv)
@staticmethod @staticmethod
def ctr256_encrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes: def ctr256_encrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes:
return tgcrypto.ctr_encrypt(data, key, iv, state or bytearray(1)) return tgcrypto.ctr256_encrypt(data, key, iv, state or bytearray(1))
@staticmethod @staticmethod
def ctr256_decrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes: def ctr256_decrypt(data: bytes, key: bytes, iv: bytearray, state: bytearray = None) -> bytes:
return tgcrypto.ctr_decrypt(data, key, iv, state or bytearray(1)) return tgcrypto.ctr256_decrypt(data, key, iv, state or bytearray(1))
@staticmethod @staticmethod
def xor(a: bytes, b: bytes) -> bytes: def xor(a: bytes, b: bytes) -> bytes:

View File

@ -20,6 +20,18 @@ from random import randint
class Prime: class Prime:
CURRENT_DH_PRIME = int(
"C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
16
)
# Recursive variant # Recursive variant
# @classmethod # @classmethod
# def gcd(cls, a: int, b: int) -> int: # def gcd(cls, a: int, b: int) -> int:

View File

@ -26,7 +26,7 @@ from pyrogram.api import functions, types
from pyrogram.api.core import Object, Long, Int from pyrogram.api.core import Object, Long, Int
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import AES, RSA, Prime from pyrogram.crypto import AES, RSA, Prime
from .internals import MsgId, DataCenter from .internals import MsgId
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -34,21 +34,10 @@ log = logging.getLogger(__name__)
class Auth: class Auth:
MAX_RETRIES = 5 MAX_RETRIES = 5
CURRENT_DH_PRIME = int( def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict):
"C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
"48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
"20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
"2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
"A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
"FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
"E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
"0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
16
)
def __init__(self, dc_id: int, test_mode: bool, proxy: dict):
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.ipv6 = ipv6
self.proxy = proxy self.proxy = proxy
self.connection = None self.connection = None
@ -84,7 +73,7 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail. # The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times. # If that happens, just try again up to MAX_RETRIES times.
while True: while True:
self.connection = Connection(DataCenter(self.dc_id, self.test_mode), self.proxy) self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)
try: try:
log.info("Start creating a new auth key on DC{}".format(self.dc_id)) log.info("Start creating a new auth key on DC{}".format(self.dc_id))
@ -219,7 +208,7 @@ class Auth:
# Security checks # Security checks
####################### #######################
assert dh_prime == self.CURRENT_DH_PRIME assert dh_prime == Prime.CURRENT_DH_PRIME
log.debug("DH parameters check: OK") log.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation

View File

@ -34,5 +34,32 @@ class DataCenter:
121: "95.213.217.195" 121: "95.213.217.195"
} }
def __new__(cls, dc_id: int, test_mode: bool): TEST_IPV6 = {
return (cls.TEST[dc_id], 80) if test_mode else (cls.PROD[dc_id], 443) 1: "2001:b28:f23d:f001::e",
2: "2001:67c:4e8:f002::e",
3: "2001:b28:f23d:f003::e",
121: "2a03:b0c0:3:d0::114:d001"
}
PROD_IPV6 = {
1: "2001:b28:f23d:f001::a",
2: "2001:67c:4e8:f002::a",
3: "2001:b28:f23d:f003::a",
4: "2001:67c:4e8:f004::a",
5: "2001:b28:f23f:f005::a",
121: "2a03:b0c0:3:d0::114:d001"
}
def __new__(cls, dc_id: int, test_mode: bool, ipv6: bool):
if ipv6:
return (
(cls.TEST_IPV6[dc_id], 80)
if test_mode
else (cls.PROD_IPV6[dc_id], 443)
)
else:
return (
(cls.TEST[dc_id], 80)
if test_mode
else (cls.PROD[dc_id], 443)
)

View File

@ -30,7 +30,7 @@ from pyrogram.api.core import Object, MsgContainer, Int, Long, FutureSalt, Futur
from pyrogram.api.errors import Error, InternalServerError, AuthKeyDuplicated from pyrogram.api.errors import Error, InternalServerError, AuthKeyDuplicated
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import MTProto from pyrogram.crypto import MTProto
from .internals import MsgId, MsgFactory, DataCenter from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -109,7 +109,7 @@ class Session:
async def start(self): async def start(self):
while True: while True:
self.connection = Connection(DataCenter(self.dc_id, self.client.test_mode), self.client.proxy) self.connection = Connection(self.dc_id, self.client.test_mode, self.client.ipv6, self.client.proxy)
try: try:
await self.connection.connect() await self.connection.connect()
@ -142,7 +142,10 @@ class Session:
self.ping_task = asyncio.ensure_future(self.ping()) self.ping_task = asyncio.ensure_future(self.ping())
log.info("Connection inited: Layer {}".format(layer)) log.info("Session initialized: Layer {}".format(layer))
log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper()))
except AuthKeyDuplicated as e: except AuthKeyDuplicated as e:
self.stop() self.stop()
raise e raise e

View File

@ -46,25 +46,9 @@ def get_readme():
class Clean(Command): class Clean(Command):
DIST = [ DIST = ["./build", "./dist", "./Pyrogram.egg-info"]
"./build", API = ["pyrogram/api/errors/exceptions", "pyrogram/api/functions", "pyrogram/api/types", "pyrogram/api/all.py"]
"./dist", DOCS = ["docs/source/functions", "docs/source/types", "docs/build"]
"./Pyrogram.egg-info"
]
API = [
"pyrogram/api/errors/exceptions",
"pyrogram/api/functions",
"pyrogram/api/types",
"pyrogram/api/all.py",
]
DOCS = [
"docs/source/functions",
"docs/source/types",
"docs/build"
]
ALL = DIST + API + DOCS ALL = DIST + API + DOCS
description = "Clean generated files" description = "Clean generated files"
@ -102,7 +86,7 @@ class Clean(Command):
if self.docs: if self.docs:
paths.update(Clean.DOCS) paths.update(Clean.DOCS)
if self.all: if self.all or not paths:
paths.update(Clean.ALL) paths.update(Clean.ALL)
for path in sorted(list(paths)): for path in sorted(list(paths)):
@ -114,12 +98,12 @@ class Clean(Command):
print("removing {}".format(path)) print("removing {}".format(path))
class Build(Command): class Generate(Command):
description = "Build Pyrogram files" description = "Generate Pyrogram files"
user_options = [ user_options = [
("api", None, "Build API files"), ("api", None, "Generate API files"),
("docs", None, "Build docs files"), ("docs", None, "Generate docs files")
] ]
def __init__(self, dist, **kw): def __init__(self, dist, **kw):
@ -191,6 +175,6 @@ setup(
extras_require={"tgcrypto": ["tgcrypto>=1.0.4"]}, extras_require={"tgcrypto": ["tgcrypto>=1.0.4"]},
cmdclass={ cmdclass={
"clean": Clean, "clean": Clean,
"build": Build, "generate": Generate
} }
) )