Force keyword arguments for all TL types

This commit is contained in:
Dan 2019-03-16 16:50:40 +01:00
parent e0f1f6aaeb
commit 34b51b6481
28 changed files with 133 additions and 108 deletions

View File

@ -287,9 +287,11 @@ def start():
sorted_args = sort_args(c.args)
arguments = ", " + ", ".join(
[get_argument_type(i) for i in sorted_args if i != ("flags", "#")]
) if c.args else ""
arguments = (
", "
+ ("*, " if c.args else "")
+ (", ".join([get_argument_type(i) for i in sorted_args if i != ("flags", "#")]) if c.args else "")
)
fields = "\n ".join(
["self.{0} = {0} # {1}".format(i[0], i[1]) for i in c.args if i != ("flags", "#")]
@ -456,7 +458,9 @@ def start():
fields=fields,
read_types=read_types,
write_types=write_types,
return_arguments=", ".join([i[0] for i in sorted_args if i != ("flags", "#")]),
return_arguments=", ".join(
["{0}={0}".format(i[0]) for i in sorted_args if i != ("flags", "#")]
),
slots=", ".join(['"{}"'.format(i[0]) for i in sorted_args if i != ("flags", "#")]),
qualname="{}{}".format("{}.".format(c.namespace) if c.namespace else "", c.name)
)

View File

@ -627,9 +627,9 @@ class Client(Methods, BaseClient):
try:
r = self.send(
functions.auth.SignIn(
self.phone_number,
phone_code_hash,
self.phone_code
phone_number=self.phone_number,
phone_code_hash=phone_code_hash,
phone_code=self.phone_code
)
)
except PhoneNumberUnoccupied:
@ -640,11 +640,11 @@ class Client(Methods, BaseClient):
try:
r = self.send(
functions.auth.SignUp(
self.phone_number,
phone_code_hash,
self.phone_code,
self.first_name,
self.last_name
phone_number=self.phone_number,
phone_code_hash=phone_code_hash,
phone_code=self.phone_code,
first_name=self.first_name,
last_name=self.last_name
)
)
except PhoneNumberOccupied:
@ -738,7 +738,11 @@ class Client(Methods, BaseClient):
break
if terms_of_service:
assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id))
assert self.send(
functions.help.AcceptTermsOfService(
id=terms_of_service.id
)
)
self.password = None
self.user_id = r.user.id
@ -1036,10 +1040,10 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client has not been started")
if self.no_updates:
data = functions.InvokeWithoutUpdates(data)
data = functions.InvokeWithoutUpdates(query=data)
if self.takeout_id:
data = functions.InvokeWithTakeout(self.takeout_id, data)
data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data)
r = self.session.send(data, retries, timeout)
@ -1353,7 +1357,7 @@ class Client(Methods, BaseClient):
self.fetch_peers(
self.send(
functions.users.GetUsers(
id=[types.InputUser(peer_id, 0)]
id=[types.InputUser(user_id=peer_id, access_hash=0)]
)
)
)
@ -1361,7 +1365,7 @@ class Client(Methods, BaseClient):
if str(peer_id).startswith("-100"):
self.send(
functions.channels.GetChannels(
id=[types.InputChannel(int(str(peer_id)[4:]), 0)]
id=[types.InputChannel(channel_id=int(str(peer_id)[4:]), access_hash=0)]
)
)
else:
@ -1668,8 +1672,8 @@ class Client(Methods, BaseClient):
hashes = session.send(
functions.upload.GetCdnFileHashes(
r.file_token,
offset
file_token=r.file_token,
offset=offset
)
)

View File

@ -67,10 +67,10 @@ def get_peer_id(input_peer) -> int:
def get_input_peer(peer_id: int, access_hash: int):
return (
types.InputPeerUser(peer_id, access_hash) if peer_id > 0
else types.InputPeerChannel(int(str(peer_id)[4:]), access_hash)
types.InputPeerUser(user_id=peer_id, access_hash=access_hash) if peer_id > 0
else types.InputPeerChannel(channel_id=int(str(peer_id)[4:]), access_hash=access_hash)
if (str(peer_id).startswith("-100") and access_hash)
else types.InputPeerChat(-peer_id)
else types.InputPeerChat(chat_id=-peer_id)
)

View File

@ -45,7 +45,7 @@ class ExportChatInviteLink(BaseClient):
if isinstance(peer, types.InputPeerChat):
return self.send(
functions.messages.ExportChatInvite(
chat_id=peer.chat_id
peer=peer.chat_id
)
).link
elif isinstance(peer, types.InputPeerChannel):

View File

@ -67,10 +67,10 @@ class GetChat(BaseClient):
peer = self.resolve_peer(chat_id)
if isinstance(peer, types.InputPeerChannel):
r = self.send(functions.channels.GetFullChannel(peer))
r = self.send(functions.channels.GetFullChannel(channel=peer))
elif isinstance(peer, (types.InputPeerUser, types.InputPeerSelf)):
r = self.send(functions.users.GetFullUser(peer))
r = self.send(functions.users.GetFullUser(id=peer))
else:
r = self.send(functions.messages.GetFullChat(peer.chat_id))
r = self.send(functions.messages.GetFullChat(chat_id=peer.chat_id))
return pyrogram.Chat._parse_full(self, r)

View File

@ -92,7 +92,7 @@ class GetChatMembers(BaseClient):
self,
self.send(
functions.messages.GetFullChat(
peer.chat_id
chat_id=peer.chat_id
)
)
)

View File

@ -39,7 +39,7 @@ class GetContacts(BaseClient):
"""
while True:
try:
contacts = self.send(functions.contacts.GetContacts(0))
contacts = self.send(functions.contacts.GetContacts(hash=0))
except FloodWait as e:
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
time.sleep(e.x)

View File

@ -131,7 +131,9 @@ class EditMessageMedia(BaseClient):
w=media.width,
h=media.height
),
types.DocumentAttributeFilename(os.path.basename(media.media))
types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
)
]
)
)
@ -187,7 +189,9 @@ class EditMessageMedia(BaseClient):
performer=media.performer,
title=media.title
),
types.DocumentAttributeFilename(os.path.basename(media.media))
types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
)
]
)
)
@ -244,7 +248,9 @@ class EditMessageMedia(BaseClient):
w=media.width,
h=media.height
),
types.DocumentAttributeFilename(os.path.basename(media.media)),
types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
),
types.DocumentAttributeAnimated()
]
)
@ -296,7 +302,9 @@ class EditMessageMedia(BaseClient):
thumb=None if media.thumb is None else self.save_file(media.thumb),
file=self.save_file(media.media),
attributes=[
types.DocumentAttributeFilename(os.path.basename(media.media))
types.DocumentAttributeFilename(
file_name=os.path.basename(media.media)
)
]
)
)

View File

@ -76,7 +76,7 @@ class GetMessages(BaseClient):
is_iterable = not isinstance(ids, int)
ids = list(ids) if is_iterable else [ids]
ids = [ids_type(i) for i in ids]
ids = [ids_type(id=i) for i in ids]
if isinstance(peer, types.InputPeerChannel):
rpc = functions.channels.GetMessages(channel=peer, id=ids)

View File

@ -141,7 +141,7 @@ class SendAnimation(BaseClient):
w=width,
h=height
),
types.DocumentAttributeFilename(os.path.basename(animation)),
types.DocumentAttributeFilename(file_name=os.path.basename(animation)),
types.DocumentAttributeAnimated()
]
)

View File

@ -142,7 +142,7 @@ class SendAudio(BaseClient):
performer=performer,
title=title
),
types.DocumentAttributeFilename(os.path.basename(audio))
types.DocumentAttributeFilename(file_name=os.path.basename(audio))
]
)
elif audio.startswith("http"):

View File

@ -123,7 +123,7 @@ class SendDocument(BaseClient):
file=file,
thumb=thumb,
attributes=[
types.DocumentAttributeFilename(os.path.basename(document))
types.DocumentAttributeFilename(file_name=os.path.basename(document))
]
)
elif document.startswith("http"):

View File

@ -69,9 +69,9 @@ class SendLocation(BaseClient):
functions.messages.SendMedia(
peer=self.resolve_peer(chat_id),
media=types.InputMediaGeoPoint(
types.InputGeoPoint(
latitude,
longitude
geo_point=types.InputGeoPoint(
lat=latitude,
long=longitude
)
),
message="",

View File

@ -137,7 +137,7 @@ class SendMediaGroup(BaseClient):
w=i.width,
h=i.height
),
types.DocumentAttributeFilename(os.path.basename(i.media))
types.DocumentAttributeFilename(file_name=os.path.basename(i.media))
]
)
)

View File

@ -103,7 +103,7 @@ class SendSticker(BaseClient):
mime_type="image/webp",
file=file,
attributes=[
types.DocumentAttributeFilename(os.path.basename(sticker))
types.DocumentAttributeFilename(file_name=os.path.basename(sticker))
]
)
elif sticker.startswith("http"):

View File

@ -145,7 +145,7 @@ class SendVideo(BaseClient):
w=width,
h=height
),
types.DocumentAttributeFilename(os.path.basename(video))
types.DocumentAttributeFilename(file_name=os.path.basename(video))
]
)
elif video.startswith("http"):

View File

@ -101,4 +101,4 @@ def compute_check(r: types.account.Password, password: str) -> types.InputCheckP
+ K_bytes
)
return types.InputCheckPasswordSRP(srp_id, A_bytes, M1_bytes)
return types.InputCheckPasswordSRP(srp_id=srp_id, A=A_bytes, M1=M1_bytes)

View File

@ -35,7 +35,7 @@ class GetMe(BaseClient):
self,
self.send(
functions.users.GetFullUser(
types.InputPeerSelf()
id=types.InputPeerSelf()
)
).user
)

View File

@ -43,7 +43,7 @@ class SetUserProfilePhoto(BaseClient):
return bool(
self.send(
functions.photos.UploadProfilePhoto(
self.save_file(photo)
file=self.save_file(photo)
)
)
)

View File

@ -55,20 +55,20 @@ class HTML:
input_user = self.peers_by_id.get(user_id, None)
entity = (
Mention(start, len(body), input_user)
if input_user else MentionInvalid(start, len(body), user_id)
Mention(offset=start, length=len(body), user_id=input_user)
if input_user else MentionInvalid(offset=start, length=len(body), user_id=user_id)
)
else:
entity = Url(start, len(body), url)
entity = Url(offset=start, length=len(body), url=url)
else:
if style == "b" or style == "strong":
entity = Bold(start, len(body))
entity = Bold(offset=start, length=len(body))
elif style == "i" or style == "em":
entity = Italic(start, len(body))
entity = Italic(offset=start, length=len(body))
elif style == "code":
entity = Code(start, len(body))
entity = Code(offset=start, length=len(body))
elif style == "pre":
entity = Pre(start, len(body), "")
entity = Pre(offset=start, length=len(body), language="")
else:
continue

View File

@ -72,24 +72,24 @@ class Markdown:
input_user = self.peers_by_id.get(user_id, None)
entity = (
Mention(start, len(text), input_user)
Mention(offset=start, length=len(text), user_id=input_user)
if input_user
else MentionInvalid(start, len(text), user_id)
else MentionInvalid(offset=start, length=len(text), user_id=user_id)
)
else:
entity = Url(start, len(text), url)
entity = Url(offset=start, length=len(text), url=url)
body = text
offset += len(url) + 4
else:
if style == self.BOLD_DELIMITER:
entity = Bold(start, len(body))
entity = Bold(offset=start, length=len(body))
elif style == self.ITALIC_DELIMITER:
entity = Italic(start, len(body))
entity = Italic(offset=start, length=len(body))
elif style == self.CODE_DELIMITER:
entity = Code(start, len(body))
entity = Code(offset=start, length=len(body))
elif style == self.PRE_DELIMITER:
entity = Pre(start, len(body), "")
entity = Pre(offset=start, length=len(body), language="")
else:
continue

View File

@ -111,16 +111,20 @@ class InlineKeyboardButton(PyrogramType):
def write(self):
if self.callback_data:
return KeyboardButtonCallback(self.text, self.callback_data)
return KeyboardButtonCallback(text=self.text, data=self.callback_data)
if self.url:
return KeyboardButtonUrl(self.text, self.url)
return KeyboardButtonUrl(text=self.text, url=self.url)
if self.switch_inline_query:
return KeyboardButtonSwitchInline(self.text, self.switch_inline_query)
return KeyboardButtonSwitchInline(text=self.text, query=self.switch_inline_query)
if self.switch_inline_query_current_chat:
return KeyboardButtonSwitchInline(self.text, self.switch_inline_query_current_chat, same_peer=True)
return KeyboardButtonSwitchInline(
text=self.text,
query=self.switch_inline_query_current_chat,
same_peer=True
)
if self.callback_game:
return KeyboardButtonGame(self.text)
return KeyboardButtonGame(text=self.text)

View File

@ -59,7 +59,7 @@ class InlineKeyboardMarkup(PyrogramType):
def write(self):
return ReplyInlineMarkup(
[KeyboardButtonRow(
[j.write() for j in i]
rows=[KeyboardButtonRow(
buttons=[j.write() for j in i]
) for i in self.inline_keyboard]
)

View File

@ -75,8 +75,8 @@ class KeyboardButton(PyrogramType):
# TODO: Enforce optional args mutual exclusiveness
if self.request_contact:
return KeyboardButtonRequestPhone(self.text)
return KeyboardButtonRequestPhone(text=self.text)
elif self.request_location:
return KeyboardButtonRequestGeoLocation(self.text)
return KeyboardButtonRequestGeoLocation(text=self.text)
else:
return RawKeyboardButton(self.text)
return RawKeyboardButton(text=self.text)

View File

@ -87,9 +87,11 @@ class ReplyKeyboardMarkup(PyrogramType):
def write(self):
return RawReplyKeyboardMarkup(
rows=[KeyboardButtonRow(
[KeyboardButton(j).write()
if isinstance(j, str) else j.write()
for j in i]
buttons=[
KeyboardButton(j).write()
if isinstance(j, str) else j.write()
for j in i
]
) for i in self.keyboard],
resize=self.resize_keyboard or None,
single_use=self.one_time_keyboard or None,

View File

@ -103,7 +103,10 @@ class Sticker(PyrogramType):
try:
return send(
functions.messages.GetStickerSet(
types.InputStickerSetID(*input_sticker_set_id)
stickerset=types.InputStickerSetID(
id=input_sticker_set_id[0],
access_hash=input_sticker_set_id[1]
)
)
).set.short_name
except StickersetInvalid:

View File

@ -45,10 +45,10 @@ class Auth:
@staticmethod
def pack(data: Object) -> bytes:
return (
bytes(8)
+ Long(MsgId())
+ Int(len(data.write()))
+ data.write()
bytes(8)
+ Long(MsgId())
+ Int(len(data.write()))
+ data.write()
)
@staticmethod
@ -83,7 +83,7 @@ class Auth:
# Step 1; Step 2
nonce = int.from_bytes(urandom(16), "little", signed=True)
log.debug("Send req_pq: {}".format(nonce))
res_pq = self.send(functions.ReqPqMulti(nonce))
res_pq = self.send(functions.ReqPqMulti(nonce=nonce))
log.debug("Got ResPq: {}".format(res_pq.server_nonce))
log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))
@ -110,12 +110,12 @@ class Auth:
new_nonce = int.from_bytes(urandom(32), "little", signed=True)
data = types.PQInnerData(
res_pq.pq,
p.to_bytes(4, "big"),
q.to_bytes(4, "big"),
nonce,
server_nonce,
new_nonce,
pq=res_pq.pq,
p=p.to_bytes(4, "big"),
q=q.to_bytes(4, "big"),
nonce=nonce,
server_nonce=server_nonce,
new_nonce=new_nonce,
).write()
sha = sha1(data).digest()
@ -129,12 +129,12 @@ class Auth:
log.debug("Send req_DH_params")
server_dh_params = self.send(
functions.ReqDHParams(
nonce,
server_nonce,
p.to_bytes(4, "big"),
q.to_bytes(4, "big"),
public_key_fingerprint,
encrypted_data
nonce=nonce,
server_nonce=server_nonce,
p=p.to_bytes(4, "big"),
q=q.to_bytes(4, "big"),
public_key_fingerprint=public_key_fingerprint,
encrypted_data=encrypted_data
)
)
@ -144,13 +144,13 @@ class Auth:
new_nonce = new_nonce.to_bytes(32, "little", signed=True)
tmp_aes_key = (
sha1(new_nonce + server_nonce).digest()
+ sha1(server_nonce + new_nonce).digest()[:12]
sha1(new_nonce + server_nonce).digest()
+ sha1(server_nonce + new_nonce).digest()[:12]
)
tmp_aes_iv = (
sha1(server_nonce + new_nonce).digest()[12:]
+ sha1(new_nonce + new_nonce).digest() + new_nonce[:4]
sha1(server_nonce + new_nonce).digest()[12:]
+ sha1(new_nonce + new_nonce).digest() + new_nonce[:4]
)
server_nonce = int.from_bytes(server_nonce, "little", signed=True)
@ -175,10 +175,10 @@ class Auth:
retry_id = 0
data = types.ClientDHInnerData(
nonce,
server_nonce,
retry_id,
g_b
nonce=nonce,
server_nonce=server_nonce,
retry_id=retry_id,
g_b=g_b
).write()
sha = sha1(data).digest()
@ -189,9 +189,9 @@ class Auth:
log.debug("Send set_client_DH_params")
set_client_dh_params_answer = self.send(
functions.SetClientDHParams(
nonce,
server_nonce,
encrypted_data
nonce=nonce,
server_nonce=server_nonce,
encrypted_data=encrypted_data
)
)

View File

@ -134,11 +134,11 @@ class Session:
self.current_salt = FutureSalt(
0, 0,
self._send(
functions.Ping(0),
functions.Ping(ping_id=0),
timeout=self.START_TIMEOUT
).new_server_salt
)
self.current_salt = self._send(functions.GetFutureSalts(1), timeout=self.START_TIMEOUT).salts[0]
self.current_salt = self._send(functions.GetFutureSalts(num=1), timeout=self.START_TIMEOUT).salts[0]
self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
self.next_salt_thread.start()
@ -146,8 +146,8 @@ class Session:
if not self.is_cdn:
self._send(
functions.InvokeWithLayer(
layer,
functions.InitConnection(
layer=layer,
query=functions.InitConnection(
api_id=self.client.api_id,
app_version=self.client.app_version,
device_model=self.client.device_model,
@ -314,7 +314,7 @@ class Session:
log.info("Send {} acks".format(len(self.pending_acks)))
try:
self._send(types.MsgsAck(list(self.pending_acks)), False)
self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False)
except (OSError, TimeoutError):
pass
else:
@ -335,7 +335,7 @@ class Session:
try:
self._send(functions.PingDelayDisconnect(
0, self.WAIT_TIMEOUT + 10
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
), False)
except (OSError, TimeoutError, Error):
pass
@ -365,7 +365,7 @@ class Session:
break
try:
self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]
self.current_salt = self._send(functions.GetFutureSalts(num=1)).salts[0]
except (OSError, TimeoutError, Error):
self.connection.close()
break