JustinLin610's picture
first commit
ee21b96
raw
history blame
5.27 kB
import os, hashlib
import requests
from tqdm import tqdm
import importlib
URL_MAP = {
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}
CKPT_MAP = {
"vgg_lpips": "vgg.pth"
}
MD5_MAP = {
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class KeyNotFoundError(Exception):
def __init__(self, cause, keys=None, visited=None):
self.cause = cause
self.keys = keys
self.visited = visited
messages = list()
if keys is not None:
messages.append("Key not found: {}".format(keys))
if visited is not None:
messages.append("Visited: {}".format(visited))
messages.append("Cause:\n{}".format(cause))
message = "\n".join(messages)
super().__init__(message)
def retrieve(
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
"""Given a nested list or dict return the desired value at key expanding
callable nodes if necessary and :attr:`expand` is ``True``. The expansion
is done in-place.
Parameters
----------
list_or_dict : list or dict
Possibly nested list or dictionary.
key : str
key/to/value, path like string describing all keys necessary to
consider to get to the desired value. List indices can also be
passed here.
splitval : str
String that defines the delimiter between keys of the
different depth levels in `key`.
default : obj
Value returned if :attr:`key` is not found.
expand : bool
Whether to expand callable nodes on the path or not.
Returns
-------
The desired value or if :attr:`default` is not ``None`` and the
:attr:`key` is not found returns ``default``.
Raises
------
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
``None``.
"""
keys = key.split(splitval)
success = True
try:
visited = []
parent = None
last_key = None
for key in keys:
if callable(list_or_dict):
if not expand:
raise KeyNotFoundError(
ValueError(
"Trying to get past callable node with expand=False."
),
keys=keys,
visited=visited,
)
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
last_key = key
parent = list_or_dict
try:
if isinstance(list_or_dict, dict):
list_or_dict = list_or_dict[key]
else:
list_or_dict = list_or_dict[int(key)]
except (KeyError, IndexError, ValueError) as e:
raise KeyNotFoundError(e, keys=keys, visited=visited)
visited += [key]
# final expansion of retrieved value
if expand and callable(list_or_dict):
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
except KeyNotFoundError as e:
if default is None:
raise e
else:
list_or_dict = default
success = False
if not pass_success:
return list_or_dict
else:
return list_or_dict, success
if __name__ == "__main__":
config = {"keya": "a",
"keyb": "b",
"keyc":
{"cc1": 1,
"cc2": 2,
}
}
from omegaconf import OmegaConf
config = OmegaConf.create(config)
print(config)
retrieve(config, "keya")