StreamingSVD / utils /loader.py
lev1's picture
Initial commit
8fd2f2f
import importlib
from functools import partialmethod
from pathlib import Path
from torchvision.datasets.utils import download_url
import gdown
from utils.aux import ensure_annotation_class
def get_class(cls_path: str, *args, **kwargs):
    """Import a class by dotted path, optionally pre-binding constructor arguments.

    Args:
        cls_path: Fully qualified dotted path of the class, e.g.
            ``"package.module.ClassName"``.
        *args: Positional arguments to pre-bind to the class's ``__init__``.
        **kwargs: Keyword arguments to pre-bind to the class's ``__init__``.

    Returns:
        The imported class itself when no arguments are given. Otherwise a
        subclass of it (same ``__name__``, ``__qualname__``, ``__module__``
        and ``__doc__``) whose ``__init__`` has the given arguments pre-bound.

    Raises:
        ValueError: If ``cls_path`` contains no module part.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    module_name, _, class_name = cls_path.rpartition(".")
    if not module_name:
        raise ValueError(
            f"Expected a dotted path 'module.ClassName', got {cls_path!r}")

    class_ = getattr(importlib.import_module(module_name), class_name)

    # Nothing to bind: hand back the class untouched.
    if not args and not kwargs:
        return class_

    # Bind the arguments on a subclass instead of assigning to
    # class_.__init__: patching the imported class would change it for every
    # other user, and repeated calls would stack the bound arguments.
    namespace = {
        "__init__": partialmethod(class_.__init__, *args, **kwargs),
        "__module__": class_.__module__,
        "__qualname__": class_.__qualname__,
        "__doc__": class_.__doc__,
    }
    # type(class_) keeps the original metaclass for the derived class.
    return type(class_)(class_.__name__, (class_,), namespace)
@ensure_annotation_class
def download_ckpt(local_path: Path, global_path: str) -> str:
    """Return the local checkpoint path, downloading the checkpoint if missing.

    If ``local_path`` already exists nothing is downloaded. Otherwise the
    source is chosen from ``global_path`` / ``local_path``:

    * a ``drive.google.com`` URL containing ``"file"`` is fetched with
      ``gdown.download`` directly to ``local_path``;
    * a ``drive.google.com`` URL containing ``"folder"`` is fetched with
      ``gdown.download_folder`` into the parent directory of ``local_path``;
    * when ``local_path`` ends in ``.safetensors`` or has no file extension,
      ``https://huggingface.co/{global_path}`` is downloaded to ``local_path``.

    NOTE(review): ``ensure_annotation_class`` comes from ``utils.aux`` and is
    not visible here; presumably it coerces the arguments to the annotated
    types (e.g. ``str`` -> ``Path``) -- confirm.

    Args:
        local_path: Destination of the checkpoint on disk.
        global_path: Google Drive URL, or the path appended to
            ``https://huggingface.co/``.

    Returns:
        ``local_path`` as a POSIX-style string.

    Raises:
        NotImplementedError: If no download method matches the arguments.
        AssertionError: If the checkpoint is still missing after the download.
        Exception: Any error raised by the Hugging Face download is re-raised
            after an error message is printed.
    """
    if local_path.exists():
        return local_path.as_posix()

    # exist_ok avoids a crash if the directory appears between a separate
    # existence check and the mkdir call.
    local_path.parent.mkdir(parents=True, exist_ok=True)

    is_google_drive = "drive.google.com" in global_path
    if is_google_drive and "file" in global_path:
        gdown.download(url=global_path, output=local_path.as_posix(),
                       fuzzy=True)
    elif is_google_drive and "folder" in global_path:
        gdown.download_folder(url=global_path,
                              output=local_path.parent.as_posix())
    elif local_path.suffix in (".safetensors", ""):
        # Test the file's own suffix, not the whole path: a dot in a parent
        # directory (e.g. "./ckpts/model") must not disqualify an
        # extension-less checkpoint.
        ckpt_url = f"https://huggingface.co/{global_path}"
        try:
            download_url(ckpt_url, local_path.parent.as_posix(),
                         local_path.name)
        except Exception:
            print(
                f"Error: Failed to download model from {ckpt_url} to {local_path}")
            raise
    else:
        raise NotImplementedError(
            f"Download model file {global_path} not supported")

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # the exception type and message are unchanged for existing callers.
    if not local_path.exists():
        raise AssertionError(f"Missing checkpoint {local_path}")
    return local_path.as_posix()