import os
import shutil
import sys
from typing import Iterable, Iterator, Optional, BinaryIO

from telethon.tl.types import Message, DocumentAttributeFilename

if sys.version_info < (3, 8):
    # functools.cached_property only exists on Python 3.8+. Older interpreters
    # fall back to a plain (uncached) property.
    cached_property = property
else:
    from functools import cached_property


# Size in bytes of each chunk copied when concatenating split parts (1 MiB).
CHUNK_FILE_SIZE = 1024 * 1024


def pipe_file(read_file_name: str, write_file: BinaryIO) -> None:
    """Copy the content of a file, by name, into another file that is already open.

    :param read_file_name: path of the file to read (opened in binary mode).
    :param write_file: binary file object the data is written to. It is left open.
    """
    with open(read_file_name, "rb") as read_file:
        # Copy in fixed-size chunks so a large part never has to fit in memory.
        shutil.copyfileobj(read_file, write_file, CHUNK_FILE_SIZE)


class JoinStrategyBase:
    """Base class to inherit join strategies.

    The strategies depend on the file type. For example, zip files and rar
    files are not merged in the same way.

    An instance represents one bundle: the group of download files that are
    joined together by ``join_download_files``.
    """

    def __init__(self):
        # Download files collected into this bundle, in insertion order.
        self.download_files = []

    def is_part(self, download_file: 'DownloadFile') -> bool:
        """Return whether the download file is part of this bundle."""
        raise NotImplementedError

    def add_download_file(self, download_file: 'DownloadFile') -> None:
        """Add a download file to this bundle.

        Files equal to one already in the bundle (``DownloadFile.__eq__``
        compares file names) are ignored.
        """
        if download_file in self.download_files:
            return
        self.download_files.append(download_file)

    @classmethod
    def is_applicable(cls, download_file: 'DownloadFile') -> bool:
        """Return whether this strategy is applicable to the download file."""
        raise NotImplementedError

    def join_download_files(self) -> None:
        """Join the downloaded files in the bundle."""
        raise NotImplementedError


class UnionJoinStrategy(JoinStrategyBase):
    """Join split files by plain concatenation, without any extra processing.

    The parts share the same base name and have a numeric extension. They
    must be numbered consecutively starting at zero (e.g. ``name.00``,
    ``name.01``, ``name.02``...). The joined file is named after the base
    name (``name``).
    """

    # Base name shared by every part of the bundle; set by the first file added.
    base_name: Optional[str] = None

    @staticmethod
    def get_base_name(download_file: 'DownloadFile') -> str:
        """Return the file name without its last extension."""
        return download_file.file_name.rsplit(".", 1)[0]

    def add_download_file(self, download_file: 'DownloadFile') -> None:
        """Add a download file to this bundle, fixing the base name on the first call."""
        if self.base_name is None:
            self.base_name = self.get_base_name(download_file)
        super().add_download_file(download_file)

    def is_part(self, download_file: 'DownloadFile') -> bool:
        """Return whether the download file is another part of this bundle.

        The file must have a numeric extension AND share the bundle's base
        name. Checking the extension keeps e.g. ``name.txt`` out of a
        ``name.00``/``name.01`` bundle, whose join would otherwise fail when
        converting ``txt`` to a part number.
        """
        return (self.is_applicable(download_file)
                and self.base_name == self.get_base_name(download_file))

    @classmethod
    def is_applicable(cls, download_file: 'DownloadFile') -> bool:
        """Return whether the file extension is a part number.

        ``isdecimal`` (rather than ``isdigit``) is used because it accepts
        exactly the digit characters ``int()`` can parse. For example,
        ``"²".isdigit()`` is True but ``int("²")`` fails.
        """
        return download_file.file_name_extension.isdecimal()

    def join_download_files(self) -> None:
        """Concatenate the downloaded parts of the bundle into a single file.

        Only parts whose downloaded file exists on disk are considered. If
        those parts are not numbered exactly ``0, 1, ..., n - 1``, some part
        is missing and nothing is written or removed. After a successful
        join, the part files are deleted.
        """
        # Sort numerically: a string sort misorders unpadded numbers ("10" < "9").
        parts = sorted(
            (download_file for download_file in self.download_files
             if self.is_applicable(download_file)
             and os.path.lexists(download_file.downloaded_file_name or "")),
            key=lambda download_file: int(download_file.file_name_extension),
        )
        part_numbers = [int(part.file_name_extension) for part in parts]
        if not parts or part_numbers != list(range(len(parts))):
            # There are parts of the file missing. Stopping...
            return
        # NOTE(review): the output path comes from the Telegram file-name
        # attribute and is relative to the current working directory, not to
        # the directory of the downloaded parts. Confirm that this is intended
        # and that the remote name cannot contain directory components
        # (e.g. "../").
        with open(self.get_base_name(parts[0]), "wb") as new_file:
            for part in parts:
                pipe_file(part.downloaded_file_name, new_file)
        for part in parts:
            os.remove(part.downloaded_file_name)


# Join strategy classes, tried in order by get_join_strategy.
JOIN_STRATEGIES = [
    UnionJoinStrategy,
]


def get_join_strategy(download_file: 'DownloadFile') -> Optional[JoinStrategyBase]:
    """Get a join strategy for the download file.

    :returns: a new instance of the first applicable strategy in
        ``JOIN_STRATEGIES``, already containing ``download_file``, or None
        when no strategy applies.
    """
    for strategy_cls in JOIN_STRATEGIES:
        if strategy_cls.is_applicable(download_file):
            strategy = strategy_cls()
            strategy.add_download_file(download_file)
            return strategy
    return None


class DownloadFile:
    """File to download.

    This wraps the Telethon message that contains the file.
    """

    # Local path of the downloaded file; None until set_download_file_name is called.
    downloaded_file_name: Optional[str] = None

    def __init__(self, message: Message):
        """Create the download file instance from the message."""
        self.message = message

    def set_download_file_name(self, file_name: str) -> None:
        """After downloading the file, set the final downloaded file name."""
        self.downloaded_file_name = file_name

    @cached_property
    def filename_attr(self) -> Optional[DocumentAttributeFilename]:
        """Return the document's file-name attribute, or None if it has none."""
        return next(
            (attribute for attribute in self.document.attributes
             if isinstance(attribute, DocumentAttributeFilename)),
            None,
        )

    @cached_property
    def file_name(self) -> str:
        """Return the file name, or ``'Unknown'`` when the document has none."""
        return self.filename_attr.file_name if self.filename_attr else 'Unknown'

    @property
    def file_name_extension(self) -> str:
        """Return the text after the last dot of the file name, or "" if there is no dot."""
        parts = self.file_name.rsplit(".", 1)
        return parts[-1] if len(parts) >= 2 else ""

    @property
    def document(self):
        """Return the message document."""
        return self.message.document

    @property
    def size(self) -> int:
        """Return the file size, as reported by the document."""
        return self.document.size

    def __eq__(self, other: object) -> bool:
        """Compare download files by their file name."""
        if not isinstance(other, DownloadFile):
            # Let Python try the reflected comparison instead of crashing
            # on a missing ``file_name`` attribute.
            return NotImplemented
        return self.file_name == other.file_name

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__`` (by file name)."""
        return hash(self.file_name)


class DownloadSplitFilesBase:
    """Iterate over complete and split files.

    Base class to inherit. Subclasses implement ``get_iterator``.
    """

    def __init__(self, messages: Iterable[Message]):
        self.messages = messages
        # Created by __iter__, or lazily by the first __next__ call.
        self._iterator: Optional[Iterator['DownloadFile']] = None

    def get_iterator(self) -> Iterator[DownloadFile]:
        """Get an iterator with the download files."""
        raise NotImplementedError

    def __iter__(self) -> 'DownloadSplitFilesBase':
        """(Re)start iteration with a fresh iterator from ``get_iterator``."""
        self._iterator = self.get_iterator()
        return self

    def __next__(self) -> 'DownloadFile':
        """Get the next download file, creating the iterator if needed."""
        if self._iterator is None:
            self._iterator = self.get_iterator()
        return next(self._iterator)


class KeepDownloadSplitFiles(DownloadSplitFilesBase):
    """Download split files without joining them."""

    def get_iterator(self) -> Iterator[DownloadFile]:
        """Get an iterator with one download file per message."""
        return map(DownloadFile, self.messages)


class JoinDownloadSplitFiles(DownloadSplitFilesBase):
    """Download split files and join them."""

    def get_iterator(self) -> Iterator[DownloadFile]:
        """Get an iterator with the download files, joining split files.

        Each download file is yielded BEFORE it is added to a bundle.
        Presumably the consumer downloads it, and sets its downloaded file
        name, before requesting the next one. The join only uses parts whose
        downloaded file exists.

        A bundle is joined when a file that is not part of it arrives. The
        last bundle is joined after the messages are exhausted. If the
        consumer stops iterating early, that last bundle is never joined.
        """
        current_join_strategy: Optional[JoinStrategyBase] = None
        for message in self.messages:
            download_file = DownloadFile(message)
            yield download_file
            if current_join_strategy is not None:
                if current_join_strategy.is_part(download_file):
                    # The file belongs to the bundle in process.
                    current_join_strategy.add_download_file(download_file)
                    continue
                # The file is not part of the bundle in process: join and finish it.
                current_join_strategy.join_download_files()
            # No bundle in process: start one if the file has a strategy available.
            current_join_strategy = get_join_strategy(download_file)
        # After all the files, join the latest bundle.
        if current_join_strategy is not None:
            current_join_strategy.join_download_files()