Add initial support for downloading media

This commit is contained in:
Dan 2018-02-18 18:11:33 +01:00
parent d89d238d30
commit 15561d19d5

View File

@ -27,6 +27,7 @@ import threading
import time
from collections import namedtuple
from configparser import ConfigParser
from datetime import datetime
from hashlib import sha256, md5
from queue import Queue
from signal import signal, SIGINT, SIGTERM, SIGABRT
@ -39,8 +40,8 @@ from pyrogram.api.errors import (
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing,
ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned
)
ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned,
VolumeLocNotFound)
from pyrogram.api.types import (
User, Chat, Channel,
PeerUser, PeerChannel,
@ -49,6 +50,7 @@ from pyrogram.api.types import (
)
from pyrogram.crypto import AES
from pyrogram.session import Auth, Session
from pyrogram.session.internals import MsgId
from .input_media import InputMedia
from .style import Markdown, HTML
@ -103,6 +105,7 @@ class Client:
INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$")
DIALOGS_AT_ONCE = 100
UPDATES_WORKERS = 2
DOWNLOAD_WORKERS = 1
def __init__(self,
session_name: str,
@ -148,6 +151,8 @@ class Client:
self.update_queue = Queue()
self.update_handler = None
self.download_queue = Queue()
def start(self):
"""Use this method to start the Client after creating it.
Requires no parameters.
@ -176,7 +181,7 @@ class Client:
self.password = None
self.save_session()
self.rnd_id = self.session.msg_id
self.rnd_id = MsgId
self.get_dialogs()
for i in range(self.UPDATES_WORKERS):
@ -185,6 +190,9 @@ class Client:
for i in range(self.workers):
Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start()
for i in range(self.DOWNLOAD_WORKERS):
Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start()
mimetypes.init()
def stop(self):
@ -199,6 +207,9 @@ class Client:
for _ in range(self.workers):
self.update_queue.put(None)
for _ in range(self.DOWNLOAD_WORKERS):
self.download_queue.put(None)
def fetch_peers(self, entities: list):
for entity in entities:
if isinstance(entity, User):
@ -260,6 +271,67 @@ class Client:
if username is not None:
self.peers_by_username[username] = input_peer
def download_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
while True:
message = self.download_queue.get()
if message is None:
break
message, done = message
try:
if isinstance(message.media, types.MessageMediaDocument):
document = message.media.document
if isinstance(document, types.Document):
file_name = "doc_{}{}".format(
datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"),
mimetypes.guess_extension(document.mime_type) or ".unknown"
)
for i in document.attributes:
if isinstance(i, types.DocumentAttributeFilename):
file_name = i.file_name
break
elif isinstance(i, types.DocumentAttributeSticker):
file_name = file_name.replace("doc", "sticker")
elif isinstance(i, types.DocumentAttributeAudio):
file_name = file_name.replace("doc", "audio")
elif isinstance(i, types.DocumentAttributeVideo):
file_name = file_name.replace("doc", "video")
elif isinstance(i, types.DocumentAttributeAnimated):
file_name = file_name.replace("doc", "gif")
tmp_file_name = self.get_file(
dc_id=document.dc_id,
id=document.id,
access_hash=document.access_hash,
version=document.version
)
i = 1
while True:
try:
os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(
".".join(file_name.split(".")[:-1])
+ (" ({}).".format(i) if i > 1 else ".")
+ file_name.split(".")[-1]
))
except FileExistsError:
i += 1
else:
break
done.set()
except Exception as e:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))
def updates_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
@ -1667,8 +1739,7 @@ class Client:
part_size = 512 * 1024
file_size = os.path.getsize(path)
file_total_parts = math.ceil(file_size / part_size)
# is_big = True if file_size > 10 * 1024 * 1024 else False
is_big = False # Treat all files as not-big to have the server check for the md5 sum
is_big = True if file_size > 10 * 1024 * 1024 else False
is_missing_part = True if file_id is not None else False
file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None
@ -1759,22 +1830,19 @@ class Client:
session.start()
if volume_id: # Photos are accessed by volume_id, local_id, secret
file_name = "_".join(str(i) for i in [dc_id, volume_id, local_id, secret])
location = types.InputFileLocation(
volume_id=volume_id,
local_id=local_id,
secret=secret
)
else: # Any other file can be more easily accessed by id and access_hash
file_name = "_".join(str(i) for i in [dc_id, id, access_hash, version])
location = types.InputDocumentFileLocation(
id=id,
access_hash=access_hash,
version=version
)
file_name = str(MsgId())
limit = 1024 * 1024
offset = 0
@ -1822,6 +1890,8 @@ class Client:
cdn_session.start()
try:
with open(file_name, "wb") as f:
while True:
r2 = cdn_session.send(
functions.upload.GetCdnFile(
location=location,
@ -1832,17 +1902,17 @@ class Client:
)
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
)
else:
with open(file_name, "wb") as f:
while True:
if not isinstance(r2, types.upload.CdnFile):
except VolumeLocNotFound:
break
else:
continue
chunk = r2.bytes
@ -1861,6 +1931,7 @@ class Client:
)
)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
@ -1870,15 +1941,6 @@ class Client:
os.fsync(f.fileno())
offset += limit
r2 = cdn_session.send(
functions.upload.GetCdnFile(
location=location,
file_token=r.file_token,
offset=offset,
limit=limit
)
)
except Exception as e:
log.error(e)
finally:
@ -2238,3 +2300,8 @@ class Client:
reply_to_msg_id=reply_to_message_id
)
)
def download_media(self, message: types.Message):
done = Event()
self.download_queue.put((message, done))
done.wait()