""" Utiliy functions to load pre-trained models more easily """ | |
import os | |
import pkg_resources | |
from omegaconf import OmegaConf | |
import torch | |
from huggingface_hub import hf_hub_download | |
from imagedream.ldm.util import instantiate_from_config | |

PRETRAINED_MODELS = {
    "sd-v2.1-base-4view-ipmv": {
        "config": "sd_v2_base_ipmv.yaml",
        "repo_id": "Peng-Wang/ImageDream",
        "filename": "sd-v2.1-base-4view-ipmv.pt",
    },
    "sd-v2.1-base-4view-ipmv-local": {
        "config": "sd_v2_base_ipmv_local.yaml",
        "repo_id": "Peng-Wang/ImageDream",
        "filename": "sd-v2.1-base-4view-ipmv-local.pt",
    },
}


def get_config_file(config_path):
    """Return the absolute path of a config file bundled with the imagedream package."""
    cfg_file = pkg_resources.resource_filename(
        "imagedream", os.path.join("configs", config_path)
    )
    if not os.path.exists(cfg_file):
        raise RuntimeError(f"Config {config_path} not available!")
    return cfg_file


def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None):
    """Build an ImageDream model, loading weights from a local checkpoint or the Hugging Face Hub."""
    # If both a config and a checkpoint are given explicitly, bypass the registry.
    if (config_path is not None) and (ckpt_path is not None):
        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config.model)
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
        return model

    if model_name not in PRETRAINED_MODELS:
        raise RuntimeError(
            f"Model name {model_name} is not a pre-trained model. Available models are:\n- "
            + "\n- ".join(PRETRAINED_MODELS.keys())
        )
    model_info = PRETRAINED_MODELS[model_name]

    # Instantiate the model from its packaged config
    print(f"Loading model from config: {model_info['config']}")
    config_file = get_config_file(model_info["config"])
    config = OmegaConf.load(config_file)
    model = instantiate_from_config(config.model)

    # Load the pre-trained checkpoint from the Hugging Face Hub unless a local path is given
    if not ckpt_path:
        ckpt_path = hf_hub_download(
            repo_id=model_info["repo_id"],
            filename=model_info["filename"],
            cache_dir=cache_dir,
        )
    print(f"Loading model from cache file: {ckpt_path}")
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    return model
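

# Minimal usage sketch. Assumptions: the `imagedream` package and its bundled
# configs are installed, the checkpoint hosted on the Hugging Face Hub is
# reachable, and the object returned by instantiate_from_config is a
# torch.nn.Module. The model name is one of the keys of PRETRAINED_MODELS above.
if __name__ == "__main__":
    model = build_model("sd-v2.1-base-4view-ipmv")
    model = model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded model with {n_params / 1e6:.1f}M parameters")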