diff --git a/compiler/api/source/main_api.tl b/compiler/api/source/main_api.tl index ca2ea52f..b4d0e7a8 100644 --- a/compiler/api/source/main_api.tl +++ b/compiler/api/source/main_api.tl @@ -431,7 +431,7 @@ accountDaysTTL#b8d0afdf days:int = AccountDaysTTL; documentAttributeImageSize#6c37c15c w:int h:int = DocumentAttribute; documentAttributeAnimated#11b58939 = DocumentAttribute; documentAttributeSticker#6319d612 flags:# mask:flags.1?true alt:string stickerset:InputStickerSet mask_coords:flags.0?MaskCoords = DocumentAttribute; -documentAttributeVideo#ef02ce6 flags:# round_message:flags.0?true duration:int w:int h:int = DocumentAttribute; +documentAttributeVideo#ef02ce6 flags:# round_message:flags.0?true supports_streaming:flags.1?true duration:int w:int h:int = DocumentAttribute; documentAttributeAudio#9852f9c6 flags:# voice:flags.10?true duration:int title:flags.0?string performer:flags.1?string waveform:flags.2?bytes = DocumentAttribute; documentAttributeFilename#15590068 file_name:string = DocumentAttribute; documentAttributeHasStickers#9801d2f7 = DocumentAttribute; @@ -815,6 +815,14 @@ help.recentMeUrls#e0310d7 urls:Vector chats:Vector users:Vect inputSingleMedia#1cc6e91f flags:# media:InputMedia random_id:long message:string entities:flags.0?Vector = InputSingleMedia; +webAuthorization#cac943f2 hash:long bot_id:int domain:string browser:string platform:string date_created:int date_active:int ip:string region:string = WebAuthorization; + +account.webAuthorizations#ed56c9fc authorizations:Vector users:Vector = account.WebAuthorizations; + +inputMessageID#a676a322 id:int = InputMessage; +inputMessageReplyTo#bad88395 id:int = InputMessage; +inputMessagePinned#86872538 = InputMessage; + ---functions--- invokeAfterMsg#cb9f372d {X:Type} msg_id:long query:!X = X; @@ -868,6 +876,9 @@ account.updatePasswordSettings#fa7c4b86 current_password_hash:bytes new_settings account.sendConfirmPhoneCode#1516d7bd flags:# allow_flashcall:flags.0?true hash:string current_number:flags.0?Bool = auth.SentCode; account.confirmPhone#5f2178c3 phone_code_hash:string phone_code:string = Bool; account.getTmpPassword#4a82327e password_hash:bytes period:int = account.TmpPassword; +account.getWebAuthorizations#182e6d6f = account.WebAuthorizations; +account.resetWebAuthorization#2d01b9ef hash:long = Bool; +account.resetWebAuthorizations#682d2594 = Bool; users.getUsers#d91a548 id:Vector = Vector; users.getFullUser#ca30a5b1 id:InputUser = UserFull; @@ -888,7 +899,7 @@ contacts.getTopPeers#d4982db5 flags:# correspondents:flags.0?true bots_pm:flags. contacts.resetTopPeerRating#1ae373ac category:TopPeerCategory peer:InputPeer = Bool; contacts.resetSaved#879537f1 = Bool; -messages.getMessages#4222fa74 id:Vector = messages.Messages; +messages.getMessages#63c66506 id:Vector = messages.Messages; messages.getDialogs#191ba9c5 flags:# exclude_pinned:flags.0?true offset_date:int offset_id:int offset_peer:InputPeer limit:int = messages.Dialogs; messages.getHistory#dcbb8260 peer:InputPeer offset_id:int offset_date:int add_offset:int limit:int max_id:int min_id:int hash:int = messages.Messages; messages.search#39e9ea0 flags:# peer:InputPeer q:string from_id:flags.0?InputUser filter:MessagesFilter min_date:int max_date:int offset_id:int add_offset:int limit:int max_id:int min_id:int = messages.Messages; @@ -1016,7 +1027,7 @@ channels.readHistory#cc104937 channel:InputChannel max_id:int = Bool; channels.deleteMessages#84c1fd4e channel:InputChannel id:Vector = messages.AffectedMessages; channels.deleteUserHistory#d10dd71b channel:InputChannel user_id:InputUser = messages.AffectedHistory; channels.reportSpam#fe087810 channel:InputChannel user_id:InputUser id:Vector = Bool; -channels.getMessages#93d7b347 channel:InputChannel id:Vector = messages.Messages; +channels.getMessages#ad8c9a23 channel:InputChannel id:Vector = messages.Messages; channels.getParticipants#123e05e9 channel:InputChannel filter:ChannelParticipantsFilter offset:int limit:int hash:int = channels.ChannelParticipants; channels.getParticipant#546dd7a6 channel:InputChannel user_id:InputUser = channels.ChannelParticipant; channels.getChannels#a7f6bbb id:Vector = messages.Chats; diff --git a/compiler/error/source/400_BAD_REQUEST.tsv b/compiler/error/source/400_BAD_REQUEST.tsv index 0b92aba5..762424b9 100644 --- a/compiler/error/source/400_BAD_REQUEST.tsv +++ b/compiler/error/source/400_BAD_REQUEST.tsv @@ -44,3 +44,5 @@ FILE_ID_INVALID The file id is invalid LOCATION_INVALID The file location is invalid CHAT_ADMIN_REQUIRED The method requires admin privileges PHONE_NUMBER_BANNED The phone number is banned +ABOUT_TOO_LONG The about text is too long +MULTI_MEDIA_TOO_LONG The album contains more than 10 items diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index d5a69e73..990ccede 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -23,12 +23,14 @@ import math import mimetypes import os import re +import threading import time from collections import namedtuple from configparser import ConfigParser from hashlib import sha256, md5 +from queue import Queue from signal import signal, SIGINT, SIGTERM, SIGABRT -from threading import Event +from threading import Event, Thread from pyrogram.api import functions, types from pyrogram.api.core import Object @@ -41,8 +43,7 @@ from pyrogram.api.errors import ( ) from pyrogram.api.types import ( User, Chat, Channel, - PeerUser, PeerChat, PeerChannel, - Dialog, Message, + PeerUser, PeerChannel, InputPeerEmpty, InputPeerSelf, InputPeerUser, InputPeerChat, InputPeerChannel ) @@ -96,11 +97,12 @@ class Client: be an empty string: "" workers (:obj:`int`, optional): - Thread pool size for handling incoming messages (updates). Defaults to 4. + Thread pool size for handling incoming events (updates). Defaults to 4. """ INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$") DIALOGS_AT_ONCE = 100 + UPDATE_WORKERS = 2 def __init__(self, session_name: str, @@ -138,9 +140,12 @@ class Client: self.proxy = None self.session = None - self.update_handler = None self.is_idle = Event() + self.event_handler = None + self.update_queue = Queue() + self.event_queue = Queue() + def start(self): """Use this method to start the Client after creating it. Requires no parameters. @@ -157,7 +162,7 @@ class Client: self.proxy, self.auth_key, self.config.api_id, - workers=self.workers + client=self ) terms = self.session.start() @@ -171,7 +176,12 @@ class Client: self.rnd_id = self.session.msg_id self.get_dialogs() - self.session.update_handler = self.update_handler + + for i in range(self.UPDATE_WORKERS): + Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start() + + for i in range(self.workers): + Thread(target=self.event_worker, name="EventWorker#{}".format(i + 1)).start() mimetypes.init() @@ -181,6 +191,142 @@ class Client: """ self.session.stop() + for i in range(self.UPDATE_WORKERS): + self.update_queue.put(None) + + for i in range(self.workers): + self.event_queue.put(None) + + def fetch_peers(self, entities: list): + for entity in entities: + if isinstance(entity, User): + user_id = entity.id + + if user_id in self.peers_by_id: + continue + + access_hash = entity.access_hash + + if access_hash is None: + continue + + username = entity.username + + input_peer = InputPeerUser( + user_id=user_id, + access_hash=access_hash + ) + + self.peers_by_id[user_id] = input_peer + + if username is not None: + self.peers_by_username[username] = input_peer + + if isinstance(entity, Chat): + chat_id = entity.id + + if chat_id in self.peers_by_id: + continue + + input_peer = InputPeerChat( + chat_id=chat_id + ) + + self.peers_by_id[chat_id] = input_peer + + if isinstance(entity, Channel): + channel_id = entity.id + peer_id = int("-100" + str(channel_id)) + + if peer_id in self.peers_by_id: + continue + + access_hash = entity.access_hash + + if access_hash is None: + continue + + username = entity.username + + input_peer = InputPeerChannel( + channel_id=channel_id, + access_hash=access_hash + ) + + self.peers_by_id[peer_id] = input_peer + + if username is not None: + self.peers_by_username[username] = input_peer + + def update_worker(self): + name = threading.current_thread().name + log.debug("{} started".format(name)) + + while True: + update = self.update_queue.get() + + if update is None: + break + + try: + if isinstance(update, (types.Update, types.UpdatesCombined)): + self.fetch_peers(update.users) + self.fetch_peers(update.chats) + + for i in update.updates: + self.event_queue.put(i) + elif isinstance(update, types.UpdateShortMessage): + if update.user_id not in self.peers_by_id: + diff = self.send( + functions.updates.GetDifference( + pts=update.pts - 1, + date=update.date, + qts=-1 + ) + ) + + self.fetch_peers(diff.users) + + self.event_queue.put(update) + elif isinstance(update, types.UpdateShortChatMessage): + if update.chat_id not in self.peers_by_id: + diff = self.send( + functions.updates.GetDifference( + pts=update.pts - 1, + date=update.date, + qts=-1 + ) + ) + + self.fetch_peers(diff.users) + self.fetch_peers(diff.chats) + + self.event_queue.put(update) + elif isinstance(update, types.UpdateShort): + self.event_queue.put(update.update) + except Exception as e: + log.error(e, exc_info=True) + + log.debug("{} stopped".format(name)) + + def event_worker(self): + name = threading.current_thread().name + log.debug("{} started".format(name)) + + while True: + event = self.event_queue.get() + + if event is None: + break + + try: + if self.event_handler: + self.event_handler(self, event) + except Exception as e: + log.error(e, exc_info=True) + + log.debug("{} stopped".format(name)) + def signal_handler(self, *args): self.stop() self.is_idle.set() @@ -199,15 +345,15 @@ class Client: self.is_idle.wait() - # TODO: Better update handler - def set_update_handler(self, callback: callable): - """Use this method to set the update handler. + def set_event_handler(self, callback: callable): + """Use this method to set the event handler. Args: callback (:obj:`callable`): - A callback function that accepts a single argument: the update object. + A function that takes ``client, event`` as positional arguments. + It will be called when a new event is generated on your account. """ - self.update_handler = callback + self.event_handler = callback def send(self, data: Object): """Use this method to send :ref:`Raw Function ` queries. @@ -263,7 +409,7 @@ class Client: self.proxy, self.auth_key, self.config.api_id, - workers=self.workers + client=self ) self.session.start() @@ -433,53 +579,17 @@ class Client: ) def get_dialogs(self): - peers = [] + def parse_dialogs(d): + self.fetch_peers(d.chats) + self.fetch_peers(d.users) - def parse_dialogs(d) -> int: - oldest_date = 1 << 32 - - for dialog in d.dialogs: # type: Dialog - # Only search for Users, Chats and Channels - if not isinstance(dialog.peer, (PeerUser, PeerChat, PeerChannel)): + for m in reversed(d.messages): + if isinstance(m, types.MessageEmpty): continue - - if isinstance(dialog.peer, PeerUser): - peer_type = "user" - peer_id = dialog.peer.user_id - elif isinstance(dialog.peer, PeerChat): - peer_type = "chat" - peer_id = dialog.peer.chat_id - elif isinstance(dialog.peer, PeerChannel): - peer_type = "channel" - peer_id = dialog.peer.channel_id else: - continue - - for message in d.messages: # type: Message - is_this = peer_id == message.from_id or dialog.peer == message.to_id - - if is_this: - for entity in (d.users if peer_type == "user" else d.chats): # type: User or Chat or Channel - if entity.id == peer_id: - peers.append( - dict( - id=peer_id, - access_hash=getattr(entity, "access_hash", None), - type=peer_type, - first_name=getattr(entity, "first_name", None), - last_name=getattr(entity, "last_name", None), - title=getattr(entity, "title", None), - username=getattr(entity, "username", None), - ) - ) - - if message.date < oldest_date: - oldest_date = message.date - - break - break - - return oldest_date + return m.date + else: + return 0 pinned_dialogs = self.send(functions.messages.GetPinnedDialogs()) parse_dialogs(pinned_dialogs) @@ -492,7 +602,7 @@ class Client: ) offset_date = parse_dialogs(dialogs) - log.info("Dialogs count: {}".format(len(peers))) + log.info("Entities count: {}".format(len(self.peers_by_id))) while len(dialogs.dialogs) == self.DIALOGS_AT_ONCE: try: @@ -508,37 +618,7 @@ class Client: continue offset_date = parse_dialogs(dialogs) - log.info("Dialogs count: {}".format(len(peers))) - - for i in peers: - peer_id = i["id"] - peer_type = i["type"] - peer_username = i["username"] - peer_access_hash = i["access_hash"] - - if peer_type == "user": - input_peer = InputPeerUser( - peer_id, - peer_access_hash - ) - elif peer_type == "chat": - input_peer = InputPeerChat( - peer_id - ) - elif peer_type == "channel": - input_peer = InputPeerChannel( - peer_id, - peer_access_hash - ) - peer_id = int("-100" + str(peer_id)) - else: - continue - - self.peers_by_id[peer_id] = input_peer - - if peer_username: - peer_username = peer_username.lower() - self.peers_by_username[peer_username] = input_peer + log.info("Entities count: {}".format(len(self.peers_by_id))) def resolve_username(self, username: str): username = username.lower().strip("@") @@ -985,7 +1065,8 @@ class Client: duration=duration, w=width, h=height - ) + ), + types.DocumentAttributeFilename(os.path.basename(video)) ] ), silent=disable_notification or None, @@ -2016,7 +2097,8 @@ class Client: duration=i.duration, w=i.width, h=i.height - ) + ), + types.DocumentAttributeFilename(os.path.basename(i.media)) ] ) ) diff --git a/pyrogram/client/input_media.py b/pyrogram/client/input_media.py index d7612121..74da22b8 100644 --- a/pyrogram/client/input_media.py +++ b/pyrogram/client/input_media.py @@ -1,3 +1,22 @@ +# Pyrogram - Telegram MTProto API Client Library for Python +# Copyright (C) 2017-2018 Dan Tès +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Pyrogram. If not, see . + + class InputMedia: class Photo: """This object represents a photo to be sent inside an album. diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 23d686a4..e3e236b6 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -26,6 +26,7 @@ from os import urandom from queue import Queue from threading import Event, Thread +import pyrogram from pyrogram import __copyright__, __license__, __version__ from pyrogram.api import functions, types, core from pyrogram.api.all import layer @@ -59,7 +60,7 @@ class Session: ) INITIAL_SALT = 0x616e67656c696361 - + NET_WORKERS = 2 WAIT_TIMEOUT = 10 MAX_RETRIES = 5 ACKS_THRESHOLD = 8 @@ -74,18 +75,16 @@ class Session: auth_key: bytes, api_id: str, is_cdn: bool = False, - workers: int = 2): + client: pyrogram = None): if not Session.notice_displayed: print("Pyrogram v{}, {}".format(__version__, __copyright__)) print("Licensed under the terms of the " + __license__, end="\n\n") Session.notice_displayed = True - self.is_cdn = is_cdn - self.workers = workers - self.connection = Connection(DataCenter(dc_id, test_mode), proxy) - self.api_id = api_id + self.is_cdn = is_cdn + self.client = client self.auth_key = auth_key self.auth_key_id = sha1(auth_key).digest()[-8:] @@ -109,12 +108,6 @@ class Session: self.is_connected = Event() - self.update_handler = None - - self.total_connections = 0 - self.total_messages = 0 - self.total_bytes = 0 - def start(self): terms = None @@ -122,8 +115,8 @@ class Session: try: self.connection.connect() - for i in range(self.workers): - Thread(target=self.worker, name="Worker#{}".format(i + 1)).start() + for i in range(self.NET_WORKERS): + Thread(target=self.net_worker, name="NetWorker#{}".format(i + 1)).start() Thread(target=self.recv, name="RecvThread").start() @@ -159,7 +152,6 @@ class Session: break self.is_connected.set() - self.total_connections += 1 log.debug("Session started") @@ -182,7 +174,7 @@ class Session: self.connection.close() - for i in range(self.workers): + for i in range(self.NET_WORKERS): self.recv_queue.put(None) log.debug("Session stopped") @@ -193,10 +185,6 @@ class Session: def pack(self, message: Message): data = Long(self.current_salt.salt) + self.session_id + message.write() - # MTProto 2.0 requires a minimum of 12 padding bytes. - # I don't get why it says up to 1024 when what it actually needs after the - # required 12 bytes is just extra 0..15 padding bytes for aes - # TODO: It works, but recheck this. What's the meaning of 12..1024 padding bytes? padding = urandom(-(len(data) + 12) % 16 + 12) # 88 = 88 + 0 (outgoing message) @@ -230,7 +218,7 @@ class Session: return message - def worker(self): + def net_worker(self): name = threading.current_thread().name log.debug("{} started".format(name)) @@ -248,7 +236,6 @@ class Session: log.debug("{} stopped".format(name)) def unpack_dispatch_and_ack(self, packet: bytes): - # TODO: A better dispatcher data = self.unpack(BytesIO(packet)) messages = ( @@ -259,49 +246,36 @@ class Session: log.debug(data) - self.total_bytes += len(packet) - self.total_messages += len(messages) - - for i in messages: - if i.seq_no % 2 != 0: - if i.msg_id in self.pending_acks: + for msg in messages: + if msg.seq_no % 2 != 0: + if msg.msg_id in self.pending_acks: continue else: - self.pending_acks.add(i.msg_id) + self.pending_acks.add(msg.msg_id) - # log.debug("{}".format(type(i.body))) - - if isinstance(i.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)): - self.pending_acks.add(i.body.answer_msg_id) + if isinstance(msg.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)): + self.pending_acks.add(msg.body.answer_msg_id) continue - if isinstance(i.body, types.NewSessionCreated): + if isinstance(msg.body, types.NewSessionCreated): continue msg_id = None - if isinstance(i.body, (types.BadMsgNotification, types.BadServerSalt)): - msg_id = i.body.bad_msg_id - elif isinstance(i.body, (core.FutureSalts, types.RpcResult)): - msg_id = i.body.req_msg_id - elif isinstance(i.body, types.Pong): - msg_id = i.body.msg_id + if isinstance(msg.body, (types.BadMsgNotification, types.BadServerSalt)): + msg_id = msg.body.bad_msg_id + elif isinstance(msg.body, (core.FutureSalts, types.RpcResult)): + msg_id = msg.body.req_msg_id + elif isinstance(msg.body, types.Pong): + msg_id = msg.body.msg_id else: - if self.update_handler: - self.update_handler(i.body) + if self.client is not None: + self.client.update_queue.put(msg.body) if msg_id in self.results: - self.results[msg_id].value = getattr(i.body, "result", i.body) + self.results[msg_id].value = getattr(msg.body, "result", msg.body) self.results[msg_id].event.set() - # print( - # "This packet bytes: ({}) | Total bytes: ({})\n" - # "This packet messages: ({}) | Total messages: ({})\n" - # "Total connections: ({})".format( - # len(packet), self.total_bytes, len(messages), self.total_messages, self.total_connections - # ) - # ) - if len(self.pending_acks) >= self.ACKS_THRESHOLD: log.info("Send {} acks".format(len(self.pending_acks)))