|
import torch |
|
from torch import nn |
|
from functools import reduce |
|
from pathlib import Path |
|
|
|
from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig |
|
from ema_pytorch import EMA |
|
|
|
def exists(val): |
|
return val is not None |
|
|
|
def safeget(dictionary, keys, default = None): |
|
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary) |
|
|
|
def load_imagen_from_checkpoint( |
|
checkpoint_path, |
|
load_weights = True, |
|
load_ema_if_available = False |
|
): |
|
model_path = Path(checkpoint_path) |
|
full_model_path = str(model_path.resolve()) |
|
assert model_path.exists(), f'checkpoint not found at {full_model_path}' |
|
loaded = torch.load(str(model_path), map_location='cpu') |
|
|
|
imagen_params = safeget(loaded, 'imagen_params') |
|
imagen_type = safeget(loaded, 'imagen_type') |
|
|
|
if imagen_type == 'original': |
|
imagen_klass = ImagenConfig |
|
elif imagen_type == 'elucidated': |
|
imagen_klass = ElucidatedImagenConfig |
|
else: |
|
raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig') |
|
|
|
assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint' |
|
|
|
imagen = imagen_klass(**imagen_params).create() |
|
|
|
if not load_weights: |
|
return imagen |
|
|
|
has_ema = 'ema' in loaded |
|
should_load_ema = has_ema and load_ema_if_available |
|
|
|
imagen.load_state_dict(loaded['model']) |
|
|
|
if not should_load_ema: |
|
print('loading non-EMA version of unets') |
|
return imagen |
|
|
|
ema_unets = nn.ModuleList([]) |
|
for unet in imagen.unets: |
|
ema_unets.append(EMA(unet)) |
|
|
|
ema_unets.load_state_dict(loaded['ema']) |
|
|
|
for unet, ema_unet in zip(imagen.unets, ema_unets): |
|
unet.load_state_dict(ema_unet.ema_model.state_dict()) |
|
|
|
print('loaded EMA version of unets') |
|
return imagen |
|
|