import hashlib
import os
import tarfile
import urllib.request

from tqdm import tqdm


def print_arguments(args):
    """Print every attribute of an argparse-style namespace, one per line.

    Args:
        args: Object whose ``vars()`` mapping is printed, e.g. the result of
            ``ArgumentParser.parse_args()``.
    """
    print("----------- Configuration Arguments -----------")
    for arg, value in vars(args).items():
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")


def strtobool(val):
    """Convert a textual truth value to ``True`` or ``False``.

    Matching is case-insensitive: y/yes/t/true/on/1 map to ``True`` and
    n/no/f/false/off/0 map to ``False``.

    Raises:
        ValueError: If ``val`` is not one of the recognised spellings.
    """
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid truth value %r" % (val,))


def str_none(val):
    """Return ``None`` for the literal string ``'None'``, otherwise ``val``."""
    if val == 'None':
        return None
    return val


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register ``--<argname>`` on ``argparser`` with friendlier type parsing.

    ``bool`` is parsed with :func:`strtobool`, because ``bool("False")`` is
    ``True``. ``str`` is parsed with :func:`str_none`, so ``'None'`` becomes
    ``None``. The default value is appended to the help text.

    Args:
        argname: Option name without the leading ``--``.
        type: Conversion callable (``bool`` and ``str`` are substituted).
        default: Default value for the option.
        help: Help text; ``' Default: %(default)s.'`` is appended.
        argparser: ``argparse.ArgumentParser`` (or group) to add to.
        **kwargs: Forwarded to ``add_argument``.
    """
    type = strtobool if type == bool else type
    type = str_none if type == str else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' Default: %(default)s.',
                           **kwargs)


def md5file(fname):
    """Return the hex MD5 digest of the file at ``fname``, read in 4 KiB chunks."""
    hash_md5 = hashlib.md5()
    # ``with`` guarantees the handle is closed even if a read fails.
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum.

    The local file name is the last ``/``-separated component of ``url``. The
    download is skipped if a file with that name already exists and its MD5
    matches ``md5sum``.

    Args:
        url: Source URL, opened with ``urllib.request.urlopen``.
        md5sum: Expected hex MD5 digest of the file.
        target_dir: Directory to save into. It is created if missing; an empty
            string means the current working directory.

    Returns:
        Path of the local file.

    Raises:
        RuntimeError: If the downloaded file's MD5 does not match ``md5sum``.
    """
    if target_dir:
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(target_dir, exist_ok=True)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if os.path.exists(filepath) and md5file(filepath) == md5sum:
        print(f"File exists, skip downloading. ({filepath})")
        return filepath

    print(f"Downloading {url} to {filepath} ...")
    with urllib.request.urlopen(url) as source, open(filepath, "wb") as output:
        # Content-Length may be absent (e.g. chunked responses). Without it,
        # tqdm shows a running counter with no total instead of crashing on
        # int(None).
        length = source.info().get("Content-Length")
        total = int(length) if length is not None and length.strip().isdigit() else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    print(f"\nMD5 Checksum {filepath} ...")
    if md5file(filepath) != md5sum:
        raise RuntimeError("MD5 checksum failed.")
    return filepath


def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the tar archive ``filepath`` into ``target_dir``.

    Args:
        filepath: Path of the archive; compression is auto-detected by
            ``tarfile.open``.
        target_dir: Destination directory.
        rm_tar: If true, delete the archive after a successful extraction.
    """
    print("Unpacking %s ..." % filepath)
    with tarfile.open(filepath) as tar:
        if hasattr(tarfile, "data_filter"):
            # The "data" filter refuses members that would land outside
            # target_dir (absolute paths, "..", unsafe links). This matters
            # because archives may be fetched from remote URLs.
            tar.extractall(target_dir, filter="data")
        else:
            tar.extractall(target_dir)
    if rm_tar:
        os.remove(filepath)


def make_inputs_require_grad(module, input, output):
    """Forward-hook callback that marks ``output`` as requiring gradients.

    The ``(module, input, output)`` signature matches a PyTorch forward hook.
    It is presumably registered via ``register_forward_hook`` (confirm at the
    call site). ``output`` is assumed to expose ``requires_grad_``, as a
    ``torch.Tensor`` does.
    """
    output.requires_grad_(True)