# File size: 3,773 Bytes
# 6eb7da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import lpips
import clip


from encoders.modules import BERTEmbedder
from models.clipseg import CLIPDensePredT

from huggingface_hub import hf_hub_download

# --- Sampling / loading configuration read by make_models() -------------
# Number of sampling timesteps fed into the diffusion `timestep_respacing`.
STEPS: int = 100
# Intended to select the full 1000-step DDPM schedule (see make_models for
# how it interacts with STEPS and USE_DDIM).
USE_DDPM: bool = False
# Use a DDIM respacing ('ddim<STEPS>', or 'ddim50' when STEPS is unset).
USE_DDIM: bool = False
# Force the CPU even when CUDA is available; also keeps fp16 disabled.
USE_CPU: bool = False
# Local checkpoint holding the CLIPSeg decoder weights (loaded non-strictly).
CLIP_SEG_PATH: str = './weights/rd64-uni.pth'
# Passed to requires_grad_() of the diffusion model and the KL autoencoder.
CLIP_GUIDANCE: bool = False

def _timestep_respacing(steps, use_ddpm, use_ddim):
    """Return the ``timestep_respacing`` specifier string for guided_diffusion.

    Precedence: DDIM wins over DDPM, which wins over a plain step count.

    Args:
        steps: Requested number of sampling steps; a falsy value means "unset".
        use_ddpm: Request the full 1000-step DDPM schedule.
        use_ddim: Request DDIM sampling.

    Returns:
        ``'ddim<steps>'`` (or ``'ddim50'`` when *steps* is unset) if *use_ddim*;
        otherwise ``'1000'`` if *use_ddpm*; otherwise ``str(steps)`` if *steps*
        is set; otherwise ``''`` (presumably the "no respacing" default of
        ``model_and_diffusion_defaults()`` — verify against guided_diffusion).
    """
    if use_ddim:
        return f'ddim{steps}' if steps else 'ddim50'
    if use_ddpm:
        # Must not fall through to the `steps` branch: a truthy step count
        # would otherwise silently replace the full 1000-step schedule.
        return '1000'
    if steps:
        return str(steps)
    return ''


def make_models():
    """Build and load every model used by the inpainting pipeline.

    Configuration comes from the module-level flags ``STEPS``, ``USE_DDPM``,
    ``USE_DDIM``, ``USE_CPU``, ``CLIP_SEG_PATH`` and ``CLIP_GUIDANCE``.
    Checkpoints other than the CLIPSeg weights are fetched with
    ``hf_hub_download`` from the ``alvanlii/rdm_inpaint`` repository.
    Prints the selected device as a side effect.

    Returns:
        Tuple ``(segmodel, model, diffusion, ldm, bert, clip_model, model_params)``:

        * ``segmodel`` – ``CLIPDensePredT`` in eval mode, left on the CPU.
        * ``model``, ``diffusion`` – from ``create_model_and_diffusion``;
          ``model`` is in eval mode on the selected device.
        * ``ldm`` – module unpickled from ``kl-f8.pt``, in eval mode on the device.
        * ``bert`` – ``BERTEmbedder`` on the device, converted to fp16, eval, frozen.
        * ``clip_model`` – CLIP ``ViT-L/14``, eval, frozen.
        * ``model_params`` – the overrides applied on top of
          ``model_and_diffusion_defaults()``.
    """
    segmodel = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
    segmodel.eval()
    # Non-strict, because the checkpoint stores only the decoder weights
    # (not the CLIP backbone weights).
    segmodel.load_state_dict(torch.load(CLIP_SEG_PATH, map_location=torch.device('cpu')), strict=False)

    device = torch.device('cuda:0' if (torch.cuda.is_available() and not USE_CPU) else 'cpu')
    print('Using device:', device)

    hf_inpaint_path = hf_hub_download("alvanlii/rdm_inpaint", "inpaint.pt")
    model_state_dict = torch.load(hf_inpaint_path, map_location='cpu')

    model_params = {
        'attention_resolutions': '32,16,8',
        'class_cond': False,
        'diffusion_steps': 1000,
        'rescale_timesteps': True,
        # Derived from STEPS / USE_DDPM / USE_DDIM; change those flags instead.
        'timestep_respacing': _timestep_respacing(STEPS, USE_DDPM, USE_DDIM),
        'image_size': 32,
        'learn_sigma': False,
        'noise_schedule': 'linear',
        'num_channels': 320,
        'num_heads': 8,
        'num_res_blocks': 2,
        'resblock_updown': False,
        'use_fp16': False,
        'use_scale_shift_norm': False,
        'clip_embed_dim': 768,
        'image_condition': True,
        'super_res_condition': False,
    }

    model_config = model_and_diffusion_defaults()
    model_config.update(model_params)
    if USE_CPU:
        model_config['use_fp16'] = False

    model, diffusion = create_model_and_diffusion(**model_config)
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # inspect the returned key lists if the checkpoint layout ever changes.
    model.load_state_dict(model_state_dict, strict=False)
    del model_state_dict  # weights are copied into the model; free the CPU copy
    model.requires_grad_(CLIP_GUIDANCE).eval().to(device)
    if model_config['use_fp16']:
        model.convert_to_fp16()
    else:
        model.convert_to_fp32()

    hf_kl_path = hf_hub_download("alvanlii/rdm_inpaint", "kl-f8.pt")
    # SECURITY/COMPAT: this file is a whole pickled module (it is used as a
    # module below), so torch.load unpickles arbitrary objects from a
    # downloaded file -- only point this at a trusted repository. PyTorch 2.6+
    # defaults to weights_only=True and will refuse to load it without
    # weights_only=False.
    ldm = torch.load(hf_kl_path, map_location="cpu")
    ldm.to(device)
    ldm.eval()
    ldm.requires_grad_(CLIP_GUIDANCE)

    bert = BERTEmbedder(1280, 32)
    hf_bert_path = hf_hub_download("alvanlii/rdm_inpaint", 'bert.pt')
    bert_state_dict = torch.load(hf_bert_path, map_location="cpu")
    bert.load_state_dict(bert_state_dict)
    del bert_state_dict
    bert.to(device)
    # NOTE(review): fp16 is applied even when device is the CPU; some CPU ops
    # may lack Half kernels depending on the torch version -- confirm for USE_CPU.
    bert.half().eval()
    bert.requires_grad_(False)

    # The preprocessing transform returned by clip.load is not needed here.
    clip_model, _ = clip.load('ViT-L/14', device=device, jit=False)
    clip_model.eval().requires_grad_(False)

    return segmodel, model, diffusion, ldm, bert, clip_model, model_params


if __name__ == "__main__":
    # Script entry point: load every model once as a smoke run.
    # The returned tuple is intentionally discarded.
    _ = make_models()