Rework download_media to accept the new Message type

This commit is contained in:
Dan 2018-04-16 01:07:02 +02:00
parent 6275a4003f
commit fdac67de69

View File

@ -129,15 +129,15 @@ class Client:
OFFLINE_SLEEP = 300 OFFLINE_SLEEP = 300
MEDIA_TYPE_ID = { MEDIA_TYPE_ID = {
0: "Thumbnail", 0: "thumbnail",
2: "Photo", 2: "photo",
3: "Voice", 3: "voice",
4: "Video", 4: "video",
5: "Document", 5: "document",
8: "Sticker", 8: "sticker",
9: "Audio", 9: "audio",
10: "GIF", 10: "gif",
13: "VideoNote" 13: "video_note"
} }
def __init__(self, def __init__(self,
@ -618,71 +618,84 @@ class Client:
while True: while True:
media = self.download_queue.get() media = self.download_queue.get()
temp_file_path = ""
final_file_path = ""
if media is None: if media is None:
break break
temp_file_path = ""
final_file_path = ""
try: try:
media, file_name, done, progress, path = media media, file_name, done, progress, path = media
file_id = media.file_id
size = media.file_size
directory, file_name = os.path.split(file_name) directory, file_name = os.path.split(file_name)
directory = directory or "downloads" directory = directory or "downloads"
if isinstance(media, types.MessageMediaDocument): try:
document = media.document decoded = utils.decode(file_id)
fmt = "<iiqqqqi" if len(decoded) > 24 else "<iiqq"
if isinstance(document, types.Document): unpacked = struct.unpack(fmt, decoded)
if not file_name: except (AssertionError, binascii.Error, struct.error):
file_name = "doc_{}{}".format( raise FileIdInvalid from None
datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"),
".txt" if document.mime_type == "text/plain" else
mimetypes.guess_extension(document.mime_type) if document.mime_type else ".unknown"
)
for i in document.attributes:
if isinstance(i, types.DocumentAttributeFilename):
file_name = i.file_name
break
elif isinstance(i, types.DocumentAttributeSticker):
file_name = file_name.replace("doc", "sticker")
elif isinstance(i, types.DocumentAttributeAudio):
file_name = file_name.replace("doc", "audio")
elif isinstance(i, types.DocumentAttributeVideo):
file_name = file_name.replace("doc", "video")
elif isinstance(i, types.DocumentAttributeAnimated):
file_name = file_name.replace("doc", "gif")
temp_file_path = self.get_file(
dc_id=document.dc_id,
id=document.id,
access_hash=document.access_hash,
version=document.version,
size=document.size,
progress=progress
)
elif isinstance(media, (types.MessageMediaPhoto, types.Photo)):
if isinstance(media, types.MessageMediaPhoto):
photo = media.photo
else: else:
photo = media media_type = unpacked[0]
dc_id = unpacked[1]
id = unpacked[2]
access_hash = unpacked[3]
volume_id = None
secret = None
local_id = None
if len(decoded) > 24:
volume_id = unpacked[4]
secret = unpacked[5]
local_id = unpacked[6]
media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None)
if media_type_str:
log.info("The file_id belongs to a {}".format(media_type_str))
else:
raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
file_name = file_name or getattr(media, "file_name", None)
if isinstance(photo, types.Photo):
if not file_name: if not file_name:
file_name = "photo_{}_{}.jpg".format( if media_type == 3:
datetime.fromtimestamp(photo.date).strftime("%Y-%m-%d_%H-%M-%S"), extension = ".ogg"
self.rnd_id() elif media_type in (4, 10, 13):
extension = mimetypes.guess_extension(media.mime_type) or ".mp4"
elif media_type == 5:
extension = mimetypes.guess_extension(media.mime_type) or ".unknown"
elif media_type == 8:
extension = ".webp"
elif media_type == 9:
extension = mimetypes.guess_extension(media.mime_type) or ".mp3"
elif media_type == 0:
extension = ".jpg"
elif media_type == 2:
extension = ".jpg"
else:
continue
file_name = "{}_{}_{}{}".format(
media_type_str,
datetime.fromtimestamp(media.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
) )
photo_loc = photo.sizes[-1].location
temp_file_path = self.get_file( temp_file_path = self.get_file(
dc_id=photo_loc.dc_id, dc_id=dc_id,
volume_id=photo_loc.volume_id, id=id,
local_id=photo_loc.local_id, access_hash=access_hash,
secret=photo_loc.secret, volume_id=volume_id,
size=photo.sizes[-1].size, local_id=local_id,
secret=secret,
size=size,
progress=progress progress=progress
) )
@ -3002,14 +3015,14 @@ class Client:
return False return False
def download_media(self, def download_media(self,
message: types.Message, message: pyrogram.Message,
file_name: str = "", file_name: str = "",
block: bool = True, block: bool = True,
progress: callable = None): progress: callable = None):
"""Use this method to download the media from a Message. """Use this method to download the media from a Message.
Args: Args:
message (:obj:`Message <pyrogram.api.types.Message>`): message (:obj:`Message <pyrogram.api.types.pyrogram.Message>`):
The Message containing the media. The Message containing the media.
file_name (``str``, optional): file_name (``str``, optional):
@ -3039,19 +3052,40 @@ class Client:
Raises: Raises:
:class:`Error <pyrogram.Error>` :class:`Error <pyrogram.Error>`
""" """
if isinstance(message, (types.Message, types.Photo)): if isinstance(message, pyrogram.Message):
if message.photo:
media = message.photo[-1]
elif message.audio:
media = message.audio
elif message.document:
media = message.document
elif message.video:
media = message.video
elif message.voice:
media = message.voice
elif message.video_note:
media = message.video_note
elif message.sticker:
media = message.sticker
else:
return
elif isinstance(message, (
pyrogram.PhotoSize,
pyrogram.Audio,
pyrogram.Document,
pyrogram.Video,
pyrogram.Voice,
pyrogram.VideoNote,
pyrogram.Sticker
)):
media = message
else:
return
done = Event() done = Event()
path = [None] path = [None]
if isinstance(message, types.Message):
media = message.media
else:
media = message
if media is not None:
self.download_queue.put((media, file_name, done, progress, path)) self.download_queue.put((media, file_name, done, progress, path))
else:
return
if block: if block:
done.wait() done.wait()