# Adapted from https://github.com/pytorch/audio/

import hashlib
import logging
import os
import tarfile
import urllib
import urllib.request
import zipfile
from os.path import expanduser
from typing import Any, Iterable, List, Optional

from torch.utils.model_zoo import tqdm

logger = logging.getLogger(__name__)


def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
    """Stream url by chunk

    This is a generator: nothing is requested until it is iterated.

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).

    Yields:
        The successive chunks read from the url, each at most ``block_size`` long.
    """
    # If we already have the whole file, there is no need to download it again
    head_request = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(head_request) as response:
        url_size = int(response.info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    request = urllib.request.Request(url)
    if start_byte:
        # Ask the server for the remaining bytes only.
        request.add_header("Range", "bytes={}-".format(start_byte))

    with (
        urllib.request.urlopen(request) as upointer,
        tqdm(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            # -1 means the server sent no Content-Length: report an unknown
            # total rather than a negative one.
            total=url_size if url_size >= 0 else None,
            disable=not progress_bar,
        ) as pbar,
    ):
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            pbar.update(len(chunk))


def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
            (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).

    Raises:
        RuntimeError: If the target file already exists and ``resume`` is False, or if
            ``hash_value`` is given and the file on disk does not match it.
    """
    head_request = urllib.request.Request(url, method="HEAD")
    # The headers stay readable after the response is closed, so the
    # connection is released right away instead of being leaked.
    with urllib.request.urlopen(head_request) as response:
        req_info = response.info()

    # Detect filename
    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)

    file_exists = os.path.exists(filepath)
    if file_exists and not resume:
        raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))

    local_size: Optional[int]
    if file_exists:
        # Resuming: append to what is already on disk.
        mode = "ab"
        local_size = os.path.getsize(filepath)
    else:
        mode = "wb"
        local_size = None

    mismatch_message = "The hash of {} does not match. Delete the file manually and retry.".format(filepath)

    # The local file is already as large as the remote one: it only needs to be verified.
    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError(mismatch_message)

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    if hash_value:
        with open(filepath, "rb") as file_obj:
            if not validate_file(file_obj, hash_value, hash_type):
                raise RuntimeError(mismatch_message)


def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from.
        hash_value (str): Hash for url, as a hexadecimal string (compared case-insensitively).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: return True if its a valid file, else False.

    Raises:
        ValueError: If ``hash_type`` is neither "sha256" nor "md5".
    """
    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError(f'Unsupported hash_type "{hash_type}", expected "sha256" or "md5".')

    # Read by chunk to avoid filling memory
    while chunk := file_obj.read(1024**2):
        hash_func.update(chunk)

    # hexdigest() is always lower case; do not reject an upper-case reference hash.
    return hash_func.hexdigest() == hash_value.lower()


def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if ``target`` resolves to ``directory`` itself or a path below it."""
    root = os.path.realpath(directory)
    resolved = os.path.realpath(target)
    try:
        return os.path.commonpath([root, resolved]) == root
    except ValueError:
        # Raised e.g. for paths on different drives: certainly not contained.
        return False


def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract archive.

    The archive is first opened as a tar file and, if that fails, as a zip file.

    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
            (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        list: List of paths to extracted files even if not overwritten. For tar archives
        these are the regular-file members joined with ``to_path``; for zip archives
        they are the member names as stored in the archive.

    Raises:
        RuntimeError: If a tar member would be extracted outside of ``to_path``.
        NotImplementedError: If the archive is neither a readable tar nor a zip file.
    """
    logger.info("Extracting archive file...")
    if to_path is None:
        to_path = os.path.dirname(from_path)

    try:
        with tarfile.open(from_path, "r") as tar:
            logger.info("Opened tar file %s.", from_path)
            files = []
            for file_ in tar:  # type: Any
                file_path = os.path.join(to_path, file_.name)
                # Refuse members such as "../x" or "/etc/x" that would be
                # written outside of the destination directory.
                if not _is_within_directory(to_path, file_path):
                    raise RuntimeError(
                        "Refusing to extract {} outside of {}.".format(file_.name, to_path)
                    )
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logger.info("%s already extracted.", file_path)
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        # Not a tar archive, try zip below.
        pass

    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logger.info("Opened zip file %s.", from_path)
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logger.info("%s already extracted.", file_path)
                    if not overwrite:
                        continue
                # ZipFile.extract() itself strips absolute paths and ".." components.
                zfile.extract(file_, to_path)
            return files
    except zipfile.BadZipFile:
        pass

    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")


def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str) -> None:
    """Download dataset from kaggle.

    An ``OSError`` raised while importing or using the kaggle client is logged, with a
    hint about the expected location of the api token, and is not re-raised.

    Args:
        dataset_path (str):
            This the kaggle link to the dataset. for example vctk is
            'mfekadu/english-multispeaker-corpus-for-voice-cloning'
        dataset_name (str): Name of the folder the dataset will be saved in.
        output_path (str): Path of the location you want the dataset folder to be saved to.
    """
    data_path = os.path.join(output_path, dataset_name)
    try:
        import kaggle  # pylint: disable=import-outside-toplevel

        kaggle.api.authenticate()
        logger.info("Downloading %s...", dataset_name)
        kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
    except OSError:
        logger.exception(
            "In order to download kaggle datasets, you need to have a kaggle api token stored in your %s",
            os.path.join(expanduser("~"), ".kaggle/kaggle.json"),
        )