Rework download_media to accept the new Message type

This commit is contained in:
Dan 2018-04-16 01:07:02 +02:00
parent 6275a4003f
commit fdac67de69

View File

@ -129,15 +129,15 @@ class Client:
OFFLINE_SLEEP = 300 OFFLINE_SLEEP = 300
MEDIA_TYPE_ID = { MEDIA_TYPE_ID = {
0: "Thumbnail", 0: "thumbnail",
2: "Photo", 2: "photo",
3: "Voice", 3: "voice",
4: "Video", 4: "video",
5: "Document", 5: "document",
8: "Sticker", 8: "sticker",
9: "Audio", 9: "audio",
10: "GIF", 10: "gif",
13: "VideoNote" 13: "video_note"
} }
def __init__(self, def __init__(self,
@ -618,73 +618,86 @@ class Client:
while True: while True:
media = self.download_queue.get() media = self.download_queue.get()
temp_file_path = ""
final_file_path = ""
if media is None: if media is None:
break break
temp_file_path = ""
final_file_path = ""
try: try:
media, file_name, done, progress, path = media media, file_name, done, progress, path = media
file_id = media.file_id
size = media.file_size
directory, file_name = os.path.split(file_name) directory, file_name = os.path.split(file_name)
directory = directory or "downloads" directory = directory or "downloads"
if isinstance(media, types.MessageMediaDocument): try:
document = media.document decoded = utils.decode(file_id)
fmt = "<iiqqqqi" if len(decoded) > 24 else "<iiqq"
unpacked = struct.unpack(fmt, decoded)
except (AssertionError, binascii.Error, struct.error):
raise FileIdInvalid from None
else:
media_type = unpacked[0]
dc_id = unpacked[1]
id = unpacked[2]
access_hash = unpacked[3]
volume_id = None
secret = None
local_id = None
if isinstance(document, types.Document): if len(decoded) > 24:
if not file_name: volume_id = unpacked[4]
file_name = "doc_{}{}".format( secret = unpacked[5]
datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"), local_id = unpacked[6]
".txt" if document.mime_type == "text/plain" else
mimetypes.guess_extension(document.mime_type) if document.mime_type else ".unknown"
)
for i in document.attributes: media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None)
if isinstance(i, types.DocumentAttributeFilename):
file_name = i.file_name
break
elif isinstance(i, types.DocumentAttributeSticker):
file_name = file_name.replace("doc", "sticker")
elif isinstance(i, types.DocumentAttributeAudio):
file_name = file_name.replace("doc", "audio")
elif isinstance(i, types.DocumentAttributeVideo):
file_name = file_name.replace("doc", "video")
elif isinstance(i, types.DocumentAttributeAnimated):
file_name = file_name.replace("doc", "gif")
temp_file_path = self.get_file( if media_type_str:
dc_id=document.dc_id, log.info("The file_id belongs to a {}".format(media_type_str))
id=document.id,
access_hash=document.access_hash,
version=document.version,
size=document.size,
progress=progress
)
elif isinstance(media, (types.MessageMediaPhoto, types.Photo)):
if isinstance(media, types.MessageMediaPhoto):
photo = media.photo
else: else:
photo = media raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
if isinstance(photo, types.Photo): file_name = file_name or getattr(media, "file_name", None)
if not file_name:
file_name = "photo_{}_{}.jpg".format(
datetime.fromtimestamp(photo.date).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id()
)
photo_loc = photo.sizes[-1].location if not file_name:
if media_type == 3:
extension = ".ogg"
elif media_type in (4, 10, 13):
extension = mimetypes.guess_extension(media.mime_type) or ".mp4"
elif media_type == 5:
extension = mimetypes.guess_extension(media.mime_type) or ".unknown"
elif media_type == 8:
extension = ".webp"
elif media_type == 9:
extension = mimetypes.guess_extension(media.mime_type) or ".mp3"
elif media_type == 0:
extension = ".jpg"
elif media_type == 2:
extension = ".jpg"
else:
continue
temp_file_path = self.get_file( file_name = "{}_{}_{}{}".format(
dc_id=photo_loc.dc_id, media_type_str,
volume_id=photo_loc.volume_id, datetime.fromtimestamp(media.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
local_id=photo_loc.local_id, self.rnd_id(),
secret=photo_loc.secret, extension
size=photo.sizes[-1].size, )
progress=progress
) temp_file_path = self.get_file(
dc_id=dc_id,
id=id,
access_hash=access_hash,
volume_id=volume_id,
local_id=local_id,
secret=secret,
size=size,
progress=progress
)
if temp_file_path: if temp_file_path:
final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
@ -3002,14 +3015,14 @@ class Client:
return False return False
def download_media(self, def download_media(self,
message: types.Message, message: pyrogram.Message,
file_name: str = "", file_name: str = "",
block: bool = True, block: bool = True,
progress: callable = None): progress: callable = None):
"""Use this method to download the media from a Message. """Use this method to download the media from a Message.
Args: Args:
message (:obj:`Message <pyrogram.api.types.Message>`): message (:obj:`Message <pyrogram.api.types.pyrogram.Message>`):
The Message containing the media. The Message containing the media.
file_name (``str``, optional): file_name (``str``, optional):
@ -3039,24 +3052,45 @@ class Client:
Raises: Raises:
:class:`Error <pyrogram.Error>` :class:`Error <pyrogram.Error>`
""" """
if isinstance(message, (types.Message, types.Photo)): if isinstance(message, pyrogram.Message):
done = Event() if message.photo:
path = [None] media = message.photo[-1]
elif message.audio:
if isinstance(message, types.Message): media = message.audio
media = message.media elif message.document:
else: media = message.document
media = message elif message.video:
media = message.video
if media is not None: elif message.voice:
self.download_queue.put((media, file_name, done, progress, path)) media = message.voice
elif message.video_note:
media = message.video_note
elif message.sticker:
media = message.sticker
else: else:
return return
elif isinstance(message, (
pyrogram.PhotoSize,
pyrogram.Audio,
pyrogram.Document,
pyrogram.Video,
pyrogram.Voice,
pyrogram.VideoNote,
pyrogram.Sticker
)):
media = message
else:
return
if block: done = Event()
done.wait() path = [None]
return path[0] self.download_queue.put((media, file_name, done, progress, path))
if block:
done.wait()
return path[0]
def download_photo(self, def download_photo(self,
photo: types.Photo or types.UserProfilePhoto or types.ChatPhoto, photo: types.Photo or types.UserProfilePhoto or types.ChatPhoto,