Add initial support for downloading media
This commit is contained in:
parent
d89d238d30
commit
15561d19d5
@ -27,6 +27,7 @@ import threading
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from configparser import ConfigParser
|
||||
from datetime import datetime
|
||||
from hashlib import sha256, md5
|
||||
from queue import Queue
|
||||
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
||||
@ -39,8 +40,8 @@ from pyrogram.api.errors import (
|
||||
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
|
||||
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
|
||||
PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing,
|
||||
ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned
|
||||
)
|
||||
ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned,
|
||||
VolumeLocNotFound)
|
||||
from pyrogram.api.types import (
|
||||
User, Chat, Channel,
|
||||
PeerUser, PeerChannel,
|
||||
@ -49,6 +50,7 @@ from pyrogram.api.types import (
|
||||
)
|
||||
from pyrogram.crypto import AES
|
||||
from pyrogram.session import Auth, Session
|
||||
from pyrogram.session.internals import MsgId
|
||||
from .input_media import InputMedia
|
||||
from .style import Markdown, HTML
|
||||
|
||||
@ -103,6 +105,7 @@ class Client:
|
||||
INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$")
|
||||
DIALOGS_AT_ONCE = 100
|
||||
UPDATES_WORKERS = 2
|
||||
DOWNLOAD_WORKERS = 1
|
||||
|
||||
def __init__(self,
|
||||
session_name: str,
|
||||
@ -148,6 +151,8 @@ class Client:
|
||||
self.update_queue = Queue()
|
||||
self.update_handler = None
|
||||
|
||||
self.download_queue = Queue()
|
||||
|
||||
def start(self):
|
||||
"""Use this method to start the Client after creating it.
|
||||
Requires no parameters.
|
||||
@ -176,7 +181,7 @@ class Client:
|
||||
self.password = None
|
||||
self.save_session()
|
||||
|
||||
self.rnd_id = self.session.msg_id
|
||||
self.rnd_id = MsgId
|
||||
self.get_dialogs()
|
||||
|
||||
for i in range(self.UPDATES_WORKERS):
|
||||
@ -185,6 +190,9 @@ class Client:
|
||||
for i in range(self.workers):
|
||||
Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start()
|
||||
|
||||
for i in range(self.DOWNLOAD_WORKERS):
|
||||
Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start()
|
||||
|
||||
mimetypes.init()
|
||||
|
||||
def stop(self):
|
||||
@ -199,6 +207,9 @@ class Client:
|
||||
for _ in range(self.workers):
|
||||
self.update_queue.put(None)
|
||||
|
||||
for _ in range(self.DOWNLOAD_WORKERS):
|
||||
self.download_queue.put(None)
|
||||
|
||||
def fetch_peers(self, entities: list):
|
||||
for entity in entities:
|
||||
if isinstance(entity, User):
|
||||
@ -260,6 +271,67 @@ class Client:
|
||||
if username is not None:
|
||||
self.peers_by_username[username] = input_peer
|
||||
|
||||
def download_worker(self):
|
||||
name = threading.current_thread().name
|
||||
log.debug("{} started".format(name))
|
||||
|
||||
while True:
|
||||
message = self.download_queue.get()
|
||||
|
||||
if message is None:
|
||||
break
|
||||
|
||||
message, done = message
|
||||
|
||||
try:
|
||||
if isinstance(message.media, types.MessageMediaDocument):
|
||||
document = message.media.document
|
||||
|
||||
if isinstance(document, types.Document):
|
||||
file_name = "doc_{}{}".format(
|
||||
datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
mimetypes.guess_extension(document.mime_type) or ".unknown"
|
||||
)
|
||||
|
||||
for i in document.attributes:
|
||||
if isinstance(i, types.DocumentAttributeFilename):
|
||||
file_name = i.file_name
|
||||
break
|
||||
elif isinstance(i, types.DocumentAttributeSticker):
|
||||
file_name = file_name.replace("doc", "sticker")
|
||||
elif isinstance(i, types.DocumentAttributeAudio):
|
||||
file_name = file_name.replace("doc", "audio")
|
||||
elif isinstance(i, types.DocumentAttributeVideo):
|
||||
file_name = file_name.replace("doc", "video")
|
||||
elif isinstance(i, types.DocumentAttributeAnimated):
|
||||
file_name = file_name.replace("doc", "gif")
|
||||
|
||||
tmp_file_name = self.get_file(
|
||||
dc_id=document.dc_id,
|
||||
id=document.id,
|
||||
access_hash=document.access_hash,
|
||||
version=document.version
|
||||
)
|
||||
|
||||
i = 1
|
||||
while True:
|
||||
try:
|
||||
os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(
|
||||
".".join(file_name.split(".")[:-1])
|
||||
+ (" ({}).".format(i) if i > 1 else ".")
|
||||
+ file_name.split(".")[-1]
|
||||
))
|
||||
except FileExistsError:
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
done.set()
|
||||
except Exception as e:
|
||||
log.error(e, exc_info=True)
|
||||
|
||||
log.debug("{} stopped".format(name))
|
||||
|
||||
def updates_worker(self):
|
||||
name = threading.current_thread().name
|
||||
log.debug("{} started".format(name))
|
||||
@ -1667,8 +1739,7 @@ class Client:
|
||||
part_size = 512 * 1024
|
||||
file_size = os.path.getsize(path)
|
||||
file_total_parts = math.ceil(file_size / part_size)
|
||||
# is_big = True if file_size > 10 * 1024 * 1024 else False
|
||||
is_big = False # Treat all files as not-big to have the server check for the md5 sum
|
||||
is_big = True if file_size > 10 * 1024 * 1024 else False
|
||||
is_missing_part = True if file_id is not None else False
|
||||
file_id = file_id or self.rnd_id()
|
||||
md5_sum = md5() if not is_big and not is_missing_part else None
|
||||
@ -1759,22 +1830,19 @@ class Client:
|
||||
session.start()
|
||||
|
||||
if volume_id: # Photos are accessed by volume_id, local_id, secret
|
||||
file_name = "_".join(str(i) for i in [dc_id, volume_id, local_id, secret])
|
||||
|
||||
location = types.InputFileLocation(
|
||||
volume_id=volume_id,
|
||||
local_id=local_id,
|
||||
secret=secret
|
||||
)
|
||||
else: # Any other file can be more easily accessed by id and access_hash
|
||||
file_name = "_".join(str(i) for i in [dc_id, id, access_hash, version])
|
||||
|
||||
location = types.InputDocumentFileLocation(
|
||||
id=id,
|
||||
access_hash=access_hash,
|
||||
version=version
|
||||
)
|
||||
|
||||
file_name = str(MsgId())
|
||||
limit = 1024 * 1024
|
||||
offset = 0
|
||||
|
||||
@ -1822,63 +1890,57 @@ class Client:
|
||||
cdn_session.start()
|
||||
|
||||
try:
|
||||
r2 = cdn_session.send(
|
||||
functions.upload.GetCdnFile(
|
||||
location=location,
|
||||
file_token=r.file_token,
|
||||
offset=offset,
|
||||
limit=limit
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
||||
session.send(
|
||||
functions.upload.ReuploadCdnFile(
|
||||
file_token=r.file_token,
|
||||
request_token=r2.request_token
|
||||
with open(file_name, "wb") as f:
|
||||
while True:
|
||||
r2 = cdn_session.send(
|
||||
functions.upload.GetCdnFile(
|
||||
location=location,
|
||||
file_token=r.file_token,
|
||||
offset=offset,
|
||||
limit=limit
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
with open(file_name, "wb") as f:
|
||||
while True:
|
||||
if not isinstance(r2, types.upload.CdnFile):
|
||||
|
||||
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
||||
try:
|
||||
session.send(
|
||||
functions.upload.ReuploadCdnFile(
|
||||
file_token=r.file_token,
|
||||
request_token=r2.request_token
|
||||
)
|
||||
)
|
||||
except VolumeLocNotFound:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
chunk = r2.bytes
|
||||
chunk = r2.bytes
|
||||
|
||||
# https://core.telegram.org/cdn#decrypting-files
|
||||
decrypted_chunk = AES.ctr_decrypt(
|
||||
chunk,
|
||||
r.encryption_key,
|
||||
r.encryption_iv,
|
||||
# https://core.telegram.org/cdn#decrypting-files
|
||||
decrypted_chunk = AES.ctr_decrypt(
|
||||
chunk,
|
||||
r.encryption_key,
|
||||
r.encryption_iv,
|
||||
offset
|
||||
)
|
||||
|
||||
hashes = session.send(
|
||||
functions.upload.GetCdnFileHashes(
|
||||
r.file_token,
|
||||
offset
|
||||
)
|
||||
)
|
||||
|
||||
hashes = session.send(
|
||||
functions.upload.GetCdnFileHashes(
|
||||
r.file_token,
|
||||
offset
|
||||
)
|
||||
)
|
||||
# https://core.telegram.org/cdn#verifying-files
|
||||
for i, h in enumerate(hashes):
|
||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
||||
|
||||
for i, h in enumerate(hashes):
|
||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
||||
f.write(decrypted_chunk)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
|
||||
f.write(decrypted_chunk)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
|
||||
offset += limit
|
||||
|
||||
r2 = cdn_session.send(
|
||||
functions.upload.GetCdnFile(
|
||||
location=location,
|
||||
file_token=r.file_token,
|
||||
offset=offset,
|
||||
limit=limit
|
||||
)
|
||||
)
|
||||
offset += limit
|
||||
except Exception as e:
|
||||
log.error(e)
|
||||
finally:
|
||||
@ -2238,3 +2300,8 @@ class Client:
|
||||
reply_to_msg_id=reply_to_message_id
|
||||
)
|
||||
)
|
||||
|
||||
def download_media(self, message: types.Message):
|
||||
done = Event()
|
||||
self.download_queue.put((message, done))
|
||||
done.wait()
|
||||
|
Loading…
Reference in New Issue
Block a user