File size: 8,435 Bytes
157e137 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
import asyncio
import logging
from typing import AsyncGenerator, Dict, Union

from pyrogram import Client, raw, utils
from pyrogram.errors import AuthBytesInvalid
from pyrogram.file_id import FileId, FileType, ThumbnailSource
from pyrogram.session import Auth, Session
from pyrogram.types import Message

from FileStream.bot import work_loads

from .file_properties import get_file_ids
class ByteStreamer:
    """Stream Telegram media bytes through a pyrogram client.

    Keeps an in-memory cache mapping database IDs to ``FileId`` objects
    (wiped every ``clean_timer`` seconds by a background task) and reuses the
    media sessions stored in ``client.media_sessions``.
    """

    def __init__(self, client: Client):
        """Bind the streamer to *client* and start the cache-cleaning task.

        ``asyncio.create_task`` requires a running event loop, so instances
        must be created from inside one.
        """
        self.clean_timer = 30 * 60  # seconds between cache wipes (30 minutes)
        self.client: Client = client
        self.cached_file_ids: Dict[str, FileId] = {}
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can be garbage-collected mid-run.
        self._clean_task = asyncio.create_task(self.clean_cache())

    async def get_file_properties(self, db_id: str, multi_clients) -> FileId:
        """Return the ``FileId`` of the media identified by *db_id*.

        Served from the cache when present; otherwise generated (and cached)
        via ``generate_file_properties``.
        """
        if db_id not in self.cached_file_ids:
            logging.debug("Before Calling generate_file_properties")
            await self.generate_file_properties(db_id, multi_clients)
            logging.debug(f"Cached file properties for file with ID {db_id}")
        return self.cached_file_ids[db_id]

    async def generate_file_properties(self, db_id: str, multi_clients) -> FileId:
        """Resolve the ``FileId`` for *db_id*, store it in the cache and return it."""
        logging.debug("Before calling get_file_ids")
        # NOTE(review): the ``Message`` class itself (not an instance) is passed
        # through here; confirm this matches get_file_ids' expected signature.
        file_id = await get_file_ids(self.client, db_id, multi_clients, Message)
        logging.debug(f"Generated file ID and Unique ID for file with ID {db_id}")
        self.cached_file_ids[db_id] = file_id
        logging.debug(f"Cached media file with ID {db_id}")
        return file_id

    async def generate_media_session(self, client: Client, file_id: FileId) -> Session:
        """Return a started media ``Session`` for the DC that holds the file.

        A session already stored in ``client.media_sessions`` is reused.
        Otherwise a new one is created, started, registered there and returned.
        For a DC other than the client's home DC, the new session gets a fresh
        auth key, and the authorization is exported/imported (up to 6 attempts).

        Raises:
            AuthBytesInvalid: if every import attempt is rejected.
        """
        media_session = client.media_sessions.get(file_id.dc_id)
        if media_session is not None:
            logging.debug(f"Using cached media session for DC {file_id.dc_id}")
            return media_session

        test_mode = await client.storage.test_mode()
        if file_id.dc_id != await client.storage.dc_id():
            media_session = Session(
                client,
                file_id.dc_id,
                await Auth(client, file_id.dc_id, test_mode).create(),
                test_mode,
                is_media=True,
            )
            await media_session.start()
            # Any failure after start() (including cancellation) must stop the
            # session, otherwise it leaks: it is not yet registered anywhere.
            try:
                for _ in range(6):
                    exported_auth = await client.invoke(
                        raw.functions.auth.ExportAuthorization(dc_id=file_id.dc_id)
                    )
                    try:
                        await media_session.invoke(
                            raw.functions.auth.ImportAuthorization(
                                id=exported_auth.id, bytes=exported_auth.bytes
                            )
                        )
                        break
                    except AuthBytesInvalid:
                        logging.debug(
                            f"Invalid authorization bytes for DC {file_id.dc_id}"
                        )
                else:
                    # All attempts were rejected.
                    raise AuthBytesInvalid
            except BaseException:
                await media_session.stop()
                raise
        else:
            # Same DC as the client: the existing auth key is valid as-is.
            media_session = Session(
                client,
                file_id.dc_id,
                await client.storage.auth_key(),
                test_mode,
                is_media=True,
            )
            await media_session.start()

        logging.debug(f"Created media session for DC {file_id.dc_id}")
        client.media_sessions[file_id.dc_id] = media_session
        return media_session

    @staticmethod
    async def get_location(file_id: FileId) -> Union[raw.types.InputPhotoFileLocation,
                                                     raw.types.InputDocumentFileLocation,
                                                     raw.types.InputPeerPhotoFileLocation,]:
        """Build the raw input file location used by ``upload.GetFile``.

        Chat photos map to ``InputPeerPhotoFileLocation``, photos to
        ``InputPhotoFileLocation``, and every other file type to
        ``InputDocumentFileLocation``.
        """
        file_type = file_id.file_type
        if file_type == FileType.CHAT_PHOTO:
            if file_id.chat_id > 0:
                peer = raw.types.InputPeerUser(
                    user_id=file_id.chat_id, access_hash=file_id.chat_access_hash
                )
            elif file_id.chat_access_hash == 0:
                # Negative id with no access hash: a basic group chat.
                peer = raw.types.InputPeerChat(chat_id=-file_id.chat_id)
            else:
                peer = raw.types.InputPeerChannel(
                    channel_id=utils.get_channel_id(file_id.chat_id),
                    access_hash=file_id.chat_access_hash,
                )
            return raw.types.InputPeerPhotoFileLocation(
                peer=peer,
                volume_id=file_id.volume_id,
                local_id=file_id.local_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG,
            )
        if file_type == FileType.PHOTO:
            return raw.types.InputPhotoFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        return raw.types.InputDocumentFileLocation(
            id=file_id.media_id,
            access_hash=file_id.access_hash,
            file_reference=file_id.file_reference,
            thumb_size=file_id.thumbnail_size,
        )

    async def yield_file(
        self,
        file_id: FileId,
        index: int,
        offset: int,
        first_part_cut: int,
        last_part_cut: int,
        part_count: int,
        chunk_size: int,
    ) -> AsyncGenerator[bytes, None]:
        """Asynchronously yield the media bytes in ``chunk_size`` requests.

        Starting at *offset*, up to *part_count* chunks are fetched. The first
        chunk is trimmed from ``first_part_cut`` and the last chunk up to
        ``last_part_cut`` (both cuts apply when ``part_count == 1``).
        ``work_loads[index]`` is incremented for the generator's lifetime and
        always decremented when it finishes, fails or is closed.

        Modded from <https://github.com/eyaadh/megadlbot_oss/blob/master/mega/telegram/utils/custom_download.py#L20>
        Thanks to Eyaadh <https://github.com/eyaadh>
        """
        client = self.client
        work_loads[index] += 1
        logging.debug(f"Starting to yielding file with client {index}.")
        current_part = 1
        # Everything after the increment -- including session setup, which may
        # raise -- sits inside this try so the finally always decrements.
        try:
            media_session = await self.generate_media_session(client, file_id)
            location = await self.get_location(file_id)
            try:
                while True:
                    r = await media_session.invoke(
                        raw.functions.upload.GetFile(
                            location=location, offset=offset, limit=chunk_size
                        ),
                    )
                    if not isinstance(r, raw.types.upload.File):
                        logging.warning(
                            f"Unexpected {type(r).__name__} response at offset "
                            f"{offset}; stopping stream."
                        )
                        break
                    chunk = r.bytes
                    if not chunk:
                        break
                    if part_count == 1:
                        yield chunk[first_part_cut:last_part_cut]
                    elif current_part == 1:
                        yield chunk[first_part_cut:]
                    elif current_part == part_count:
                        yield chunk[:last_part_cut]
                    else:
                        yield chunk
                    current_part += 1
                    offset += chunk_size
                    if current_part > part_count:
                        break
            except (TimeoutError, AttributeError) as exc:
                # Best effort: end the stream early instead of propagating, but
                # leave a trace so truncated downloads can be diagnosed.
                logging.warning(
                    f"Stopped streaming after {current_part - 1} parts: {exc!r}"
                )
        finally:
            # current_part is the next part to fetch, so parts sent = current_part - 1.
            logging.debug(f"Finished yielding file with {current_part - 1} parts.")
            work_loads[index] -= 1

    async def clean_cache(self) -> None:
        """Clear the FileId cache every ``clean_timer`` seconds, forever.

        Bounds memory used by ``cached_file_ids``.
        """
        while True:
            await asyncio.sleep(self.clean_timer)
            self.cached_file_ids.clear()
            logging.debug("Cleaned the cache")
|