File size: 1,966 Bytes
239ee43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from torch import nn
from functools import reduce
from pathlib import Path

from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
from ema_pytorch import EMA

def exists(val):
    """Return True when `val` is anything other than None, False otherwise."""
    if val is None:
        return False
    return True

def safeget(dictionary, keys, default = None):
    """Follow a dot-separated key path through nested dicts.

    `keys` is split on '.' and each segment is looked up in turn, starting
    from `dictionary`. If the current value is not a dict, or the segment is
    absent from it, the walk continues from `default`. The value reached
    after the final segment is returned.
    """
    node = dictionary
    for segment in keys.split('.'):
        # a non-dict value cannot be descended into, so fall back to the default
        node = node.get(segment, default) if isinstance(node, dict) else default
    return node

def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    """Rebuild an Imagen instance from a checkpoint file.

    The checkpoint is read with `torch.load` onto the CPU. Its 'imagen_type'
    entry ('original' or 'elucidated') selects ImagenConfig or
    ElucidatedImagenConfig, which is instantiated with the saved
    'imagen_params' and asked to `.create()` the model.

    Args:
        checkpoint_path: path to the checkpoint file.
        load_weights: when False, return the freshly created model without
            loading any saved weights.
        load_ema_if_available: when True and the checkpoint has an 'ema'
            entry, overwrite each unet's weights with the EMA weights after
            the regular 'model' weights have been loaded.

    Returns:
        The object produced by the config's `.create()`, with weights loaded
        as requested.

    Raises:
        AssertionError: if the checkpoint path does not exist, or the
            checkpoint has no 'imagen_params' entry.
        ValueError: if the checkpoint has no 'imagen_type' entry, or the
            saved type is not 'original' or 'elucidated'.
    """
    model_path = Path(checkpoint_path)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'

    # NOTE(security): torch.load is pickle-based; only load checkpoints that
    # come from a trusted source.
    loaded = torch.load(str(model_path), map_location='cpu')

    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # Check for a missing type before dispatching on it. Otherwise a checkpoint
    # saved without configuration is reported as "unknown imagen type None"
    # and the "not saved" message can never be reached for the type.
    if not exists(imagen_type):
        raise ValueError('imagen type and configuration not saved in this checkpoint')

    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    assert exists(imagen_params), 'imagen type and configuration not saved in this checkpoint'

    imagen = imagen_klass(**imagen_params).create()

    if not load_weights:
        return imagen

    has_ema = 'ema' in loaded
    should_load_ema = has_ema and load_ema_if_available

    # the regular weights are always loaded first; EMA weights, if requested,
    # are copied over the unets afterwards
    imagen.load_state_dict(loaded['model'])

    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    # wrap every unet in an EMA module so the saved 'ema' state dict, which is
    # loaded into a ModuleList of EMA wrappers, lines up key-for-key
    ema_unets = nn.ModuleList([EMA(unet) for unet in imagen.unets])
    ema_unets.load_state_dict(loaded['ema'])

    # copy the averaged weights from each wrapper back into the plain unet
    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    print('loaded EMA version of unets')
    return imagen