|
|
|
|
|
import hashlib |
|
import logging |
|
import os |
|
import tarfile |
|
import urllib |
|
import urllib.request |
|
import zipfile |
|
from os.path import expanduser |
|
from typing import Any, Iterable, List, Optional |
|
|
|
from torch.utils.model_zoo import tqdm |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable[bytes]:
    """Stream the content of a url chunk by chunk.

    A ``HEAD`` request is sent first to read the ``Content-Length`` header; the
    body is then fetched with a second request and yielded in blocks.

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point. When it is
            ``None`` or ``0`` no ``Range`` header is sent (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).

    Yields:
        bytes: Consecutive chunks of at most ``block_size`` bytes, until the
        response is exhausted. Nothing is yielded when ``start_byte`` equals the
        reported ``Content-Length``.
    """
    # Headers only: learn the remote size without downloading the body.
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        url_size = int(response.info().get("Content-Length", -1))
    if url_size == start_byte:
        # The caller already has as many bytes as the server reports.
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.headers["Range"] = "bytes={}-".format(start_byte)

    # -1 is only the "header missing" sentinel; it is not a size, so the
    # progress bar is given no total at all in that case.
    total = url_size if url_size >= 0 else None

    with (
        urllib.request.urlopen(req) as upointer,
        tqdm(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            total=total,
            disable=not progress_bar,
        ) as pbar,
    ):
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            pbar.update(len(chunk))
|
|
|
|
|
def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file. It is not created by
            this function.
        filename (str or None, optional): Name of downloaded file. If None, it is
            taken from the response headers, falling back to the basename of the url
            (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).

    Raises:
        RuntimeError: If the target file already exists and ``resume`` is False, or
            if ``hash_value`` is given and the file on disk does not match it.
    """
    req = urllib.request.Request(url, method="HEAD")
    # Only the headers are needed; close the connection as soon as they are read
    # instead of leaving the response open.
    with urllib.request.urlopen(req) as response:
        req_info = response.info()

    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)

    file_exists = os.path.exists(filepath)
    if file_exists and not resume:
        raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))

    if file_exists:
        # Resume: append to what is already on disk, starting at its current size.
        mode = "ab"
        local_size: Optional[int] = os.path.getsize(filepath)
    else:
        mode = "wb"
        local_size = None

    # The local file already has the advertised size: validate it instead of
    # downloading anything.
    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    # Only reopen the file when there is a hash to check it against.
    if hash_value:
        with open(filepath, "rb") as file_obj:
            if not validate_file(file_obj, hash_value, hash_type):
                raise RuntimeError(
                    "The hash of {} does not match. Delete the file manually and retry.".format(filepath)
                )
|
|
|
|
|
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from. It is read from its current position
            until exhausted; it must return bytes.
        hash_value (str): Expected hexadecimal digest. The comparison ignores
            letter case.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: return True if it is a valid file, else False.

    Raises:
        ValueError: If ``hash_type`` is neither "sha256" nor "md5".
    """
    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError(f"Unsupported hash_type {hash_type!r}; expected 'sha256' or 'md5'.")

    # Feed the hash 1 MiB at a time so the whole file is never held in memory.
    while chunk := file_obj.read(1024**2):
        hash_func.update(chunk)

    # hexdigest() is always lower case; normalise the expected value so an
    # upper-case digest of the same bytes is not rejected.
    return hash_func.hexdigest() == hash_value.lower()
|
|
|
|
|
def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if ``target`` is located inside ``directory`` (or is the directory itself)."""
    root = os.path.abspath(directory)
    return os.path.commonpath([root, os.path.abspath(target)]) == root


def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract archive.

    The file is first opened as a tar archive; if that fails it is opened as a
    zip archive.

    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
            (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        list: List of paths to extracted files even if not overwritten. For tar
        archives these are the regular-file members joined to ``to_path``; for zip
        archives they are the member names exactly as stored in the archive.

    Raises:
        ValueError: If a tar member name would be extracted outside ``to_path``.
        NotImplementedError: If the file can be read neither as tar nor as zip.
    """
    logger.info("Extracting archive file...")
    if to_path is None:
        to_path = os.path.dirname(from_path)

    try:
        with tarfile.open(from_path, "r") as tar:
            logger.info("Opened tar file %s.", from_path)
            files = []
            for file_ in tar:
                file_path = os.path.join(to_path, file_.name)
                # Member names come from the archive itself; refuse names such as
                # "../x" or absolute paths that would land outside the destination.
                if not _is_within_directory(to_path, file_path):
                    raise ValueError(f"Refusing to extract {file_.name!r}: path is outside of {to_path!r}.")
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logger.info("%s already extracted.", file_path)
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        # Not a tar archive; fall through and try zip.
        pass

    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logger.info("Opened zip file %s.", from_path)
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logger.info("%s already extracted.", file_path)
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
            return files
    except zipfile.BadZipFile:
        # Not a zip archive either.
        pass

    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
|
|
|
|
|
def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
    """Fetch a dataset through the kaggle API and unzip it on disk.

    The files are downloaded into ``<output_path>/<dataset_name>``. An ``OSError``
    raised while importing or using the kaggle client is logged together with a
    hint about the API token location and is not re-raised.

    Args:
        dataset_path (str):
            Kaggle identifier of the dataset, for example vctk is
            'mfekadu/english-multispeaker-corpus-for-voice-cloning'.
        dataset_name (str): Name of the folder the dataset will be saved in.
        output_path (str): Path of the location you want the dataset folder to be saved to.
    """
    target_dir = os.path.join(output_path, dataset_name)
    try:
        # Imported lazily so that kaggle is only required when this function runs.
        import kaggle

        client = kaggle.api
        client.authenticate()
        logger.info("Downloading %s...", dataset_name)
        client.dataset_download_files(dataset_path, path=target_dir, unzip=True)
    except OSError:
        # NOTE(review): presumably the kaggle client raises OSError when the
        # credentials file is missing; confirm against the kaggle package.
        token_file = os.path.join(expanduser("~"), ".kaggle/kaggle.json")
        logger.exception(
            "In order to download kaggle datasets, you need to have a kaggle api token stored in your %s",
            token_file,
        )
|
|