diff --git a/compiler/error/source/400_BAD_REQUEST.tsv b/compiler/error/source/400_BAD_REQUEST.tsv index 762424b9..ac1989b8 100644 --- a/compiler/error/source/400_BAD_REQUEST.tsv +++ b/compiler/error/source/400_BAD_REQUEST.tsv @@ -46,3 +46,4 @@ CHAT_ADMIN_REQUIRED The method requires admin privileges PHONE_NUMBER_BANNED The phone number is banned ABOUT_TOO_LONG The about text is too long MULTI_MEDIA_TOO_LONG The album contains more than 10 items +USERNAME_OCCUPIED The username is already in use diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 990ccede..511a3e45 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -97,12 +97,12 @@ class Client: be an empty string: "" workers (:obj:`int`, optional): - Thread pool size for handling incoming events (updates). Defaults to 4. + Thread pool size for handling incoming updates. Defaults to 4. """ INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$") DIALOGS_AT_ONCE = 100 - UPDATE_WORKERS = 2 + UPDATES_WORKERS = 2 def __init__(self, session_name: str, @@ -133,6 +133,8 @@ class Client: self.peers_by_id = {} self.peers_by_username = {} + self.channels_pts = {} + self.markdown = Markdown(self.peers_by_id) self.html = HTML(self.peers_by_id) @@ -142,9 +144,9 @@ class Client: self.is_idle = Event() - self.event_handler = None + self.updates_queue = Queue() self.update_queue = Queue() - self.event_queue = Queue() + self.update_handler = None def start(self): """Use this method to start the Client after creating it. @@ -177,11 +179,11 @@ class Client: self.rnd_id = self.session.msg_id self.get_dialogs() - for i in range(self.UPDATE_WORKERS): - Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start() + for i in range(self.UPDATES_WORKERS): + Thread(target=self.updates_worker, name="UpdatesWorker#{}".format(i + 1)).start() for i in range(self.workers): - Thread(target=self.event_worker, name="EventWorker#{}".format(i + 1)).start() + Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start() mimetypes.init() @@ -191,11 +193,11 @@ class Client: """ self.session.stop() - for i in range(self.UPDATE_WORKERS): - self.update_queue.put(None) + for _ in range(self.UPDATES_WORKERS): + self.updates_queue.put(None) - for i in range(self.workers): - self.event_queue.put(None) + for _ in range(self.workers): + self.update_queue.put(None) def fetch_peers(self, entities: list): for entity in entities: @@ -258,6 +260,73 @@ class Client: if username is not None: self.peers_by_username[username] = input_peer + def updates_worker(self): + name = threading.current_thread().name + log.debug("{} started".format(name)) + + while True: + updates = self.updates_queue.get() + + if updates is None: + break + + try: + if isinstance(updates, (types.Update, types.UpdatesCombined)): + self.fetch_peers(updates.users) + self.fetch_peers(updates.chats) + + for update in updates.updates: + channel_id = getattr( + getattr( + getattr( + update, "message", None + ), "to_id", None + ), "channel_id", None + ) or getattr(update, "channel_id", None) + + pts = getattr(update, "pts", None) + + if channel_id and pts: + if channel_id not in self.channels_pts: + self.channels_pts[channel_id] = [] + + if pts in self.channels_pts[channel_id]: + continue + + self.channels_pts[channel_id].append(pts) + + if len(self.channels_pts[channel_id]) > 50: + self.channels_pts[channel_id] = self.channels_pts[channel_id][25:] + + self.update_queue.put((update, updates.users, updates.chats)) + elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)): + diff = self.send( + functions.updates.GetDifference( + pts=updates.pts - updates.pts_count, + date=updates.date, + qts=-1 + ) + ) + + self.fetch_peers(diff.users) + self.fetch_peers(diff.chats) + + self.update_queue.put(( + types.UpdateNewMessage( + message=diff.new_messages[0], + pts=updates.pts, + pts_count=updates.pts_count + ), + diff.users, + diff.chats + )) + elif isinstance(updates, types.UpdateShort): + self.update_queue.put((updates.update, [], [])) + except Exception as e: + log.error(e, exc_info=True) + + log.debug("{} stopped".format(name)) + def update_worker(self): name = threading.current_thread().name log.debug("{} started".format(name)) @@ -269,59 +338,13 @@ class Client: break try: - if isinstance(update, (types.Update, types.UpdatesCombined)): - self.fetch_peers(update.users) - self.fetch_peers(update.chats) - - for i in update.updates: - self.event_queue.put(i) - elif isinstance(update, types.UpdateShortMessage): - if update.user_id not in self.peers_by_id: - diff = self.send( - functions.updates.GetDifference( - pts=update.pts - 1, - date=update.date, - qts=-1 - ) - ) - - self.fetch_peers(diff.users) - - self.event_queue.put(update) - elif isinstance(update, types.UpdateShortChatMessage): - if update.chat_id not in self.peers_by_id: - diff = self.send( - functions.updates.GetDifference( - pts=update.pts - 1, - date=update.date, - qts=-1 - ) - ) - - self.fetch_peers(diff.users) - self.fetch_peers(diff.chats) - - self.event_queue.put(update) - elif isinstance(update, types.UpdateShort): - self.event_queue.put(update.update) - except Exception as e: - log.error(e, exc_info=True) - - log.debug("{} stopped".format(name)) - - def event_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) - - while True: - event = self.event_queue.get() - - if event is None: - break - - try: - if self.event_handler: - self.event_handler(self, event) + if self.update_handler: + self.update_handler( + self, + update[0], + {i.id: i for i in update[1]}, + {i.id: i for i in update[2]} + ) except Exception as e: log.error(e, exc_info=True) @@ -345,18 +368,49 @@ class Client: self.is_idle.wait() - def set_event_handler(self, callback: callable): - """Use this method to set the event handler. + def set_update_handler(self, callback: callable): + """Use this method to set the update handler. + + You must call this method *before* you *start()* the Client. Args: callback (:obj:`callable`): - A function that takes ``client, event`` as positional arguments. - It will be called when a new event is generated on your account. + A function that will be called when a new update is received from the server. It takes + :obj:`(client, update, users, chats)` as positional arguments (Look at the section below for + a detailed description). + + Other Parameters: + client (:obj:`pyrogram.Client`): + The Client itself, useful when you want to call other API methods inside the update handler. + + update (:obj:`Update`): + The received update, which can be one of the many single Updates listed in the *updates* + field you see in the :obj:`types.Update ` type. + + users (:obj:`dict`): + Dictionary of all :obj:`types.User ` mentioned in the update. + You can access extra info about the user (such as *first_name*, *last_name*, etc...) by using + the IDs you find in the *update* argument (e.g.: *users[1768841572]*). + + chats (:obj:`dict`): + Dictionary of all :obj:`types.Chat ` and + :obj:`types.Channel ` mentioned in the update. + You can access extra info about the chat (such as *title*, *participants_count*, etc...) + by using the IDs you find in the *update* argument (e.g.: *chats[1701277281]*). + + Note: + The following Empty or Forbidden types may exist inside the *users* and *chats* dictionaries. + They mean you have been blocked by the user or banned from the group/channel. + + - :obj:`types.UserEmpty ` + - :obj:`types.ChatEmpty ` + - :obj:`types.ChatForbidden ` + - :obj:`types.ChannelForbidden ` """ - self.event_handler = callback + self.update_handler = callback def send(self, data: Object): - """Use this method to send :ref:`Raw Function ` queries. + """Use this method to send Raw Function queries. This method makes possible to manually call every single Telegram API method in a low-level manner. Available functions are listed in the :obj:`pyrogram.api.functions` package and may accept compound @@ -650,24 +704,51 @@ class Client: return input_peer def resolve_peer(self, peer_id: int or str): - if peer_id in ("self", "me"): - return InputPeerSelf() - else: - if type(peer_id) is str: - peer_id = peer_id.lower().strip("@") + """Use this method to get the *InputPeer* of a known *peer_id*. - try: - return self.peers_by_username[peer_id] - except KeyError: - return self.resolve_username(peer_id) - else: - try: - return self.peers_by_id[peer_id] - except KeyError: - try: - return self.peers_by_id[int("-100" + str(peer_id))] - except KeyError: - raise PeerIdInvalid + It is intended to be used when working with Raw Functions (i.e: a Telegram API method you wish to use which is + not available yet in the Client class as an easy-to-use method). + + Args: + peer_id (:obj:`int` | :obj:`str` | :obj:`Peer`): + The Peer ID you want to extract the InputPeer from. Can be one of these types: :obj:`int` (direct ID), + :obj:`str` (@username), :obj:`PeerUser `, + :obj:`PeerChat `, :obj:`PeerChannel ` + + Returns: + :obj:`InputPeerUser ` or + :obj:`InputPeerChat ` or + :obj:`InputPeerChannel ` depending on the *peer_id*. + + Raises: + :class:`pyrogram.Error` + """ + if type(peer_id) is str: + if peer_id in ("self", "me"): + return InputPeerSelf() + + peer_id = peer_id.lower().strip("@") + + try: + return self.peers_by_username[peer_id] + except KeyError: + return self.resolve_username(peer_id) + + if type(peer_id) is not int: + if isinstance(peer_id, types.PeerUser): + peer_id = peer_id.user_id + elif isinstance(peer_id, types.PeerChat): + peer_id = peer_id.chat_id + elif isinstance(peer_id, types.PeerChannel): + peer_id = int("-100" + str(peer_id.channel_id)) + + try: + return self.peers_by_id[peer_id] + except KeyError: + try: + return self.peers_by_id[int("-100" + str(peer_id))] + except KeyError: + raise PeerIdInvalid def get_me(self): """A simple method for testing the user authorization. Requires no parameters. diff --git a/pyrogram/session/internals/msg_id.py b/pyrogram/session/internals/msg_id.py index 583e0320..cf8c0402 100644 --- a/pyrogram/session/internals/msg_id.py +++ b/pyrogram/session/internals/msg_id.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +from threading import Lock from time import time @@ -24,11 +25,13 @@ class MsgId: self.delta_time = delta_time self.last_time = 0 self.offset = 0 + self.lock = Lock() def __call__(self) -> int: - now = time() - self.offset = self.offset + 4 if now == self.last_time else 0 - msg_id = int((now + self.delta_time) * 2 ** 32) + self.offset - self.last_time = now + with self.lock: + now = time() + self.offset = self.offset + 4 if now == self.last_time else 0 + msg_id = int((now + self.delta_time) * 2 ** 32) + self.offset + self.last_time = now - return msg_id + return msg_id diff --git a/pyrogram/session/internals/seq_no.py b/pyrogram/session/internals/seq_no.py index 44a953c5..bef0d1a3 100644 --- a/pyrogram/session/internals/seq_no.py +++ b/pyrogram/session/internals/seq_no.py @@ -16,15 +16,19 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +from threading import Lock + class SeqNo: def __init__(self): self.content_related_messages_sent = 0 + self.lock = Lock() def __call__(self, is_content_related: bool) -> int: - seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0) + with self.lock: + seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0) - if is_content_related: - self.content_related_messages_sent += 1 + if is_content_related: + self.content_related_messages_sent += 1 - return seq_no + return seq_no diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index e3e236b6..8e56911f 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -60,7 +60,7 @@ class Session: ) INITIAL_SALT = 0x616e67656c696361 - NET_WORKERS = 2 + NET_WORKERS = 1 WAIT_TIMEOUT = 10 MAX_RETRIES = 5 ACKS_THRESHOLD = 8 @@ -270,7 +270,7 @@ class Session: msg_id = msg.body.msg_id else: if self.client is not None: - self.client.update_queue.put(msg.body) + self.client.updates_queue.put(msg.body) if msg_id in self.results: self.results[msg_id].value = getattr(msg.body, "result", msg.body)