Rename Object to TLObject

This commit is contained in:
Dan 2019-06-03 14:19:50 +02:00
parent 4d97aae933
commit d5517f4d5f
21 changed files with 65 additions and 62 deletions

View File

@ -53,10 +53,10 @@ def get_docstring_arg_type(t: str, is_list: bool = False, is_pyrogram_type: bool
return "``{}``".format(t.lower()) return "``{}``".format(t.lower())
elif t == "true": elif t == "true":
return "``bool``" return "``bool``"
elif t == "Object" or t == "X": elif t == "TLObject" or t == "X":
return "Any object from :obj:`pyrogram.api.types`" return "Any object from :obj:`~pyrogram.api.types`"
elif t == "!X": elif t == "!X":
return "Any method from :obj:`pyrogram.api.functions`" return "Any method from :obj:`~pyrogram.api.functions`"
elif t.startswith("Vector"): elif t.startswith("Vector"):
return "List of " + get_docstring_arg_type(t.split("<", 1)[1][:-1], True, is_pyrogram_type) return "List of " + get_docstring_arg_type(t.split("<", 1)[1][:-1], True, is_pyrogram_type)
else: else:
@ -394,7 +394,7 @@ def start():
) )
read_types += "\n " read_types += "\n "
read_types += "{} = Object.read(b{}) if flags & (1 << {}) else []\n ".format( read_types += "{} = TLObject.read(b{}) if flags & (1 << {}) else []\n ".format(
arg_name, ", {}".format(sub_type.title()) if sub_type in core_types else "", index arg_name, ", {}".format(sub_type.title()) if sub_type in core_types else "", index
) )
else: else:
@ -403,7 +403,7 @@ def start():
write_types += "b.write(self.{}.write())\n ".format(arg_name) write_types += "b.write(self.{}.write())\n ".format(arg_name)
read_types += "\n " read_types += "\n "
read_types += "{} = Object.read(b) if flags & (1 << {}) else None\n ".format( read_types += "{} = TLObject.read(b) if flags & (1 << {}) else None\n ".format(
arg_name, index arg_name, index
) )
else: else:
@ -422,7 +422,7 @@ def start():
) )
read_types += "\n " read_types += "\n "
read_types += "{} = Object.read(b{})\n ".format( read_types += "{} = TLObject.read(b{})\n ".format(
arg_name, ", {}".format(sub_type.title()) if sub_type in core_types else "" arg_name, ", {}".format(sub_type.title()) if sub_type in core_types else ""
) )
else: else:
@ -430,7 +430,7 @@ def start():
write_types += "b.write(self.{}.write())\n ".format(arg_name) write_types += "b.write(self.{}.write())\n ".format(arg_name)
read_types += "\n " read_types += "\n "
read_types += "{} = Object.read(b)\n ".format(arg_name) read_types += "{} = TLObject.read(b)\n ".format(arg_name)
if c.docs: if c.docs:
description = c.docs.split("|")[0].split("§")[1] description = c.docs.split("|")[0].split("§")[1]

View File

@ -5,7 +5,7 @@ from io import BytesIO
from pyrogram.api.core import * from pyrogram.api.core import *
class {class_name}(Object): class {class_name}(TLObject):
"""{docstring_args} """{docstring_args}
""" """

View File

@ -19,8 +19,8 @@
from importlib import import_module from importlib import import_module
from .all import objects from .all import objects
from .core.object import Object from .core.tl_object import TLObject
for k, v in objects.items(): for k, v in objects.items():
path, name = v.rsplit(".", 1) path, name = v.rsplit(".", 1)
Object.all[k] = getattr(import_module(path), name) TLObject.all[k] = getattr(import_module(path), name)

View File

@ -22,7 +22,7 @@ from .gzip_packed import GzipPacked
from .list import List from .list import List
from .message import Message from .message import Message
from .msg_container import MsgContainer from .msg_container import MsgContainer
from .object import Object from .tl_object import TLObject
from .primitives import ( from .primitives import (
Bool, BoolTrue, BoolFalse, Bytes, Double, Bool, BoolTrue, BoolFalse, Bytes, Double,
Int, Long, Int128, Int256, Null, String, Vector Int, Long, Int128, Int256, Null, String, Vector

View File

@ -18,11 +18,11 @@
from io import BytesIO from io import BytesIO
from .object import Object from .tl_object import TLObject
from .primitives import Int, Long from .primitives import Int, Long
class FutureSalt(Object): class FutureSalt(TLObject):
ID = 0x0949d9dc ID = 0x0949d9dc
__slots__ = ["valid_since", "valid_until", "salt"] __slots__ = ["valid_since", "valid_until", "salt"]

View File

@ -19,11 +19,11 @@
from io import BytesIO from io import BytesIO
from . import FutureSalt from . import FutureSalt
from .object import Object from .tl_object import TLObject
from .primitives import Int, Long from .primitives import Int, Long
class FutureSalts(Object): class FutureSalts(TLObject):
ID = 0xae500895 ID = 0xae500895
__slots__ = ["req_msg_id", "now", "salts"] __slots__ = ["req_msg_id", "now", "salts"]

View File

@ -19,24 +19,24 @@
from gzip import compress, decompress from gzip import compress, decompress
from io import BytesIO from io import BytesIO
from .object import Object from .tl_object import TLObject
from .primitives import Int, Bytes from .primitives import Int, Bytes
class GzipPacked(Object): class GzipPacked(TLObject):
ID = 0x3072cfa1 ID = 0x3072cfa1
__slots__ = ["packed_data"] __slots__ = ["packed_data"]
QUALNAME = "GzipPacked" QUALNAME = "GzipPacked"
def __init__(self, packed_data: Object): def __init__(self, packed_data: TLObject):
self.packed_data = packed_data self.packed_data = packed_data
@staticmethod @staticmethod
def read(b: BytesIO, *args) -> "GzipPacked": def read(b: BytesIO, *args) -> "GzipPacked":
# Return the Object itself instead of a GzipPacked wrapping it # Return the Object itself instead of a GzipPacked wrapping it
return Object.read( return TLObject.read(
BytesIO( BytesIO(
decompress( decompress(
Bytes.read(b) Bytes.read(b)

View File

@ -16,13 +16,13 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .object import Object from .tl_object import TLObject
class List(list, Object): class List(list, TLObject):
__slots__ = [] __slots__ = []
def __repr__(self): def __repr__(self):
return "pyrogram.api.core.List([{}])".format( return "pyrogram.api.core.List([{}])".format(
",".join(Object.__repr__(i) for i in self) ",".join(TLObject.__repr__(i) for i in self)
) )

View File

@ -18,18 +18,18 @@
from io import BytesIO from io import BytesIO
from .object import Object from .tl_object import TLObject
from .primitives import Int, Long from .primitives import Int, Long
class Message(Object): class Message(TLObject):
ID = 0x5bb8e511 # hex(crc32(b"message msg_id:long seqno:int bytes:int body:Object = Message")) ID = 0x5bb8e511 # hex(crc32(b"message msg_id:long seqno:int bytes:int body:Object = Message"))
__slots__ = ["msg_id", "seq_no", "length", "body"] __slots__ = ["msg_id", "seq_no", "length", "body"]
QUALNAME = "Message" QUALNAME = "Message"
def __init__(self, body: Object, msg_id: int, seq_no: int, length: int): def __init__(self, body: TLObject, msg_id: int, seq_no: int, length: int):
self.msg_id = msg_id self.msg_id = msg_id
self.seq_no = seq_no self.seq_no = seq_no
self.length = length self.length = length
@ -42,7 +42,7 @@ class Message(Object):
length = Int.read(b) length = Int.read(b)
body = b.read(length) body = b.read(length)
return Message(Object.read(BytesIO(body)), msg_id, seq_no, length) return Message(TLObject.read(BytesIO(body)), msg_id, seq_no, length)
def write(self) -> bytes: def write(self) -> bytes:
b = BytesIO() b = BytesIO()

View File

@ -19,11 +19,11 @@
from io import BytesIO from io import BytesIO
from .message import Message from .message import Message
from .object import Object from .tl_object import TLObject
from .primitives import Int from .primitives import Int
class MsgContainer(Object): class MsgContainer(TLObject):
ID = 0x73f1f8dc ID = 0x73f1f8dc
__slots__ = ["messages"] __slots__ = ["messages"]

View File

@ -18,10 +18,10 @@
from io import BytesIO from io import BytesIO
from ..object import Object from ..tl_object import TLObject
class BoolFalse(Object): class BoolFalse(TLObject):
ID = 0xbc799737 ID = 0xbc799737
value = False value = False
@ -38,7 +38,7 @@ class BoolTrue(BoolFalse):
value = True value = True
class Bool(Object): class Bool(TLObject):
@classmethod @classmethod
def read(cls, b: BytesIO) -> bool: def read(cls, b: BytesIO) -> bool:
return int.from_bytes(b.read(4), "little") == BoolTrue.ID return int.from_bytes(b.read(4), "little") == BoolTrue.ID

View File

@ -18,10 +18,10 @@
from io import BytesIO from io import BytesIO
from ..object import Object from ..tl_object import TLObject
class Bytes(Object): class Bytes(TLObject):
@staticmethod @staticmethod
def read(b: BytesIO, *args) -> bytes: def read(b: BytesIO, *args) -> bytes:
length = int.from_bytes(b.read(1), "little") length = int.from_bytes(b.read(1), "little")

View File

@ -19,10 +19,10 @@
from io import BytesIO from io import BytesIO
from struct import unpack, pack from struct import unpack, pack
from ..object import Object from ..tl_object import TLObject
class Double(Object): class Double(TLObject):
@staticmethod @staticmethod
def read(b: BytesIO, *args) -> float: def read(b: BytesIO, *args) -> float:
return unpack("d", b.read(8))[0] return unpack("d", b.read(8))[0]

View File

@ -18,10 +18,10 @@
from io import BytesIO from io import BytesIO
from ..object import Object from ..tl_object import TLObject
class Int(Object): class Int(TLObject):
SIZE = 4 SIZE = 4
@classmethod @classmethod

View File

@ -18,10 +18,10 @@
from io import BytesIO from io import BytesIO
from ..object import Object from ..tl_object import TLObject
class Null(Object): class Null(TLObject):
ID = 0x56730bcc ID = 0x56730bcc
@staticmethod @staticmethod

View File

@ -20,31 +20,31 @@ from io import BytesIO
from . import Int from . import Int
from ..list import List from ..list import List
from ..object import Object from ..tl_object import TLObject
class Vector(Object): class Vector(TLObject):
ID = 0x1cb5c415 ID = 0x1cb5c415
# Method added to handle the special case when a query returns a bare Vector (of Ints); # Method added to handle the special case when a query returns a bare Vector (of Ints);
# i.e., RpcResult body starts with 0x1cb5c415 (Vector Id) - e.g., messages.GetMessagesViews. # i.e., RpcResult body starts with 0x1cb5c415 (Vector Id) - e.g., messages.GetMessagesViews.
@staticmethod @staticmethod
def _read(b: BytesIO) -> Object or int: def _read(b: BytesIO) -> TLObject or int:
try: try:
return Object.read(b) return TLObject.read(b)
except KeyError: except KeyError:
b.seek(-4, 1) b.seek(-4, 1)
return Int.read(b) return Int.read(b)
@staticmethod @staticmethod
def read(b: BytesIO, t: Object = None) -> list: def read(b: BytesIO, t: TLObject = None) -> list:
return List( return List(
t.read(b) if t t.read(b) if t
else Vector._read(b) else Vector._read(b)
for _ in range(Int.read(b)) for _ in range(Int.read(b))
) )
def __new__(cls, value: list, t: Object = None) -> bytes: def __new__(cls, value: list, t: TLObject = None) -> bytes:
return b"".join( return b"".join(
[Int(cls.ID, False), Int(len(value))] [Int(cls.ID, False), Int(len(value))]
+ [ + [

View File

@ -21,7 +21,7 @@ from io import BytesIO
from json import dumps from json import dumps
class Object: class TLObject:
all = {} all = {}
__slots__ = [] __slots__ = []
@ -30,13 +30,13 @@ class Object:
@staticmethod @staticmethod
def read(b: BytesIO, *args): # TODO: Rename b -> data def read(b: BytesIO, *args): # TODO: Rename b -> data
return Object.all[int.from_bytes(b.read(4), "little")].read(b, *args) return TLObject.all[int.from_bytes(b.read(4), "little")].read(b, *args)
def write(self, *args) -> bytes: def write(self, *args) -> bytes:
pass pass
@staticmethod @staticmethod
def default(obj: "Object"): def default(obj: "TLObject"):
if isinstance(obj, bytes): if isinstance(obj, bytes):
return repr(obj) return repr(obj)
@ -50,7 +50,7 @@ class Object:
) )
def __str__(self) -> str: def __str__(self) -> str:
return dumps(self, indent=4, default=Object.default, ensure_ascii=False) return dumps(self, indent=4, default=TLObject.default, ensure_ascii=False)
def __repr__(self) -> str: def __repr__(self) -> str:
return "pyrogram.api.{}({})".format( return "pyrogram.api.{}({})".format(
@ -62,7 +62,7 @@ class Object:
) )
) )
def __eq__(self, other: "Object") -> bool: def __eq__(self, other: "TLObject") -> bool:
for attr in self.__slots__: for attr in self.__slots__:
try: try:
if getattr(self, attr) != getattr(other, attr): if getattr(self, attr) != getattr(other, attr):
@ -77,3 +77,6 @@ class Object:
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
def __setitem__(self, key, value):
setattr(self, key, value)

View File

@ -37,7 +37,7 @@ from threading import Thread
from typing import Union, List from typing import Union, List
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.api.core import Object from pyrogram.api.core import TLObject
from pyrogram.client.handlers import DisconnectHandler from pyrogram.client.handlers import DisconnectHandler
from pyrogram.client.handlers.handler import Handler from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check from pyrogram.client.methods.password.utils import compute_check
@ -998,7 +998,7 @@ class Client(Methods, BaseClient):
log.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
def send(self, data: Object, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT): def send(self, data: TLObject, retries: int = Session.MAX_RETRIES, timeout: float = Session.WAIT_TIMEOUT):
"""Send raw Telegram queries. """Send raw Telegram queries.
This method makes it possible to manually call every single Telegram API method in a low-level manner. This method makes it possible to manually call every single Telegram API method in a low-level manner.

View File

@ -23,7 +23,7 @@ from io import BytesIO
from os import urandom from os import urandom
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.api.core import Object, Long, Int from pyrogram.api.core import TLObject, Long, Int
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import AES, RSA, Prime from pyrogram.crypto import AES, RSA, Prime
from .internals import MsgId from .internals import MsgId
@ -43,7 +43,7 @@ class Auth:
self.connection = None self.connection = None
@staticmethod @staticmethod
def pack(data: Object) -> bytes: def pack(data: TLObject) -> bytes:
return ( return (
bytes(8) bytes(8)
+ Long(MsgId()) + Long(MsgId())
@ -54,9 +54,9 @@ class Auth:
@staticmethod @staticmethod
def unpack(b: BytesIO): def unpack(b: BytesIO):
b.seek(20) # Skip auth_key_id (8), message_id (8) and message_length (4) b.seek(20) # Skip auth_key_id (8), message_id (8) and message_length (4)
return Object.read(b) return TLObject.read(b)
def send(self, data: Object): def send(self, data: TLObject):
data = self.pack(data) data = self.pack(data)
self.connection.send(data) self.connection.send(data)
response = BytesIO(self.connection.recv()) response = BytesIO(self.connection.recv())
@ -158,7 +158,7 @@ class Auth:
answer_with_hash = AES.ige256_decrypt(encrypted_answer, tmp_aes_key, tmp_aes_iv) answer_with_hash = AES.ige256_decrypt(encrypted_answer, tmp_aes_key, tmp_aes_iv)
answer = answer_with_hash[20:] answer = answer_with_hash[20:]
server_dh_inner_data = Object.read(BytesIO(answer)) server_dh_inner_data = TLObject.read(BytesIO(answer))
log.debug("Done decrypting answer") log.debug("Done decrypting answer")

View File

@ -16,7 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from pyrogram.api.core import Message, MsgContainer, Object from pyrogram.api.core import Message, MsgContainer, TLObject
from pyrogram.api.functions import Ping from pyrogram.api.functions import Ping
from pyrogram.api.types import MsgsAck, HttpWait from pyrogram.api.types import MsgsAck, HttpWait
from .msg_id import MsgId from .msg_id import MsgId
@ -29,7 +29,7 @@ class MsgFactory:
def __init__(self): def __init__(self):
self.seq_no = SeqNo() self.seq_no = SeqNo()
def __call__(self, body: Object) -> Message: def __call__(self, body: TLObject) -> Message:
return Message( return Message(
body, body,
MsgId(), MsgId(),

View File

@ -30,7 +30,7 @@ import pyrogram
from pyrogram import __copyright__, __license__, __version__ from pyrogram import __copyright__, __license__, __version__
from pyrogram.api import functions, types, core from pyrogram.api import functions, types, core
from pyrogram.api.all import layer from pyrogram.api.all import layer
from pyrogram.api.core import Message, Object, MsgContainer, Long, FutureSalt, Int from pyrogram.api.core import Message, TLObject, MsgContainer, Long, FutureSalt, Int
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import AES, KDF from pyrogram.crypto import AES, KDF
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated
@ -391,7 +391,7 @@ class Session:
log.debug("RecvThread stopped") log.debug("RecvThread stopped")
def _send(self, data: Object, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data) message = self.msg_factory(data)
msg_id = message.msg_id msg_id = message.msg_id
@ -422,7 +422,7 @@ class Session:
else: else:
return result return result
def send(self, data: Object, retries: int = MAX_RETRIES, timeout: float = WAIT_TIMEOUT): def send(self, data: TLObject, retries: int = MAX_RETRIES, timeout: float = WAIT_TIMEOUT):
self.is_connected.wait(self.WAIT_TIMEOUT) self.is_connected.wait(self.WAIT_TIMEOUT)
try: try: