File size: 5,226 Bytes
94be4c7
 
 
ab11bdd
 
81ccbca
 
 
 
 
fc73e59
 
 
 
94be4c7
81ccbca
ab11bdd
0002379
81ccbca
fc73e59
 
 
 
 
 
b58675c
fc73e59
 
 
 
 
 
 
 
ab11bdd
fc73e59
 
81ccbca
fc73e59
81ccbca
fc73e59
 
 
81ccbca
fc73e59
 
 
0002379
fc73e59
 
 
 
94be4c7
 
c8aa68b
94be4c7
fc73e59
 
 
 
81ccbca
fc73e59
 
 
94be4c7
fc73e59
 
 
 
 
 
94be4c7
 
fc73e59
 
 
 
 
ab11bdd
 
 
81ccbca
 
ab11bdd
 
fc73e59
 
 
 
ab11bdd
 
fc73e59
94be4c7
81ccbca
 
7c89716
 
 
fd9afda
81ccbca
 
 
 
 
 
0002379
 
81ccbca
 
 
 
 
 
 
94be4c7
 
 
 
 
 
81ccbca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from random import random

from accelerate.utils import set_seed
from torch.cuda.amp import autocast

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

from memory_efficiency import MemoryEfficiencyWrapper


def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1):
    """Finetune a Stable Diffusion model to suppress a concept (ESD-style training).

    For each iteration, the current (finetuned) model diffuses to a random
    timestep, then the loss pushes the finetuned model's noise prediction for
    `prompt` toward the frozen model's neutral prediction minus a scaled
    concept direction, so the concept is "erased" from the finetuned modules.

    Args:
        repo_id_or_path: HuggingFace repo id or local path of the base model.
        img_size: square latent/image size used for the training latents.
        prompt: text describing the concept to erase.
        modules: module name patterns to finetune (passed to FineTunedModel).
        freeze_modules: module name patterns to keep frozen.
        iterations: number of optimization steps.
        negative_guidance: scale of the concept-removal direction in the target.
        lr: learning rate.
        save_path: where the finetuned state_dict is written with torch.save.
        use_adamw8bit: use bitsandbytes AdamW8bit instead of torch Adam.
        use_xformers / use_amp / use_gradient_checkpointing: memory options
            forwarded to MemoryEfficiencyWrapper.
        seed: training seed; -1 picks a random seed.
    """
    nsteps = 50  # DDIM steps used for the partial diffusion rollout
    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing )
    with memory_efficiency_wrapper:
        diffuser.train()
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        # Text embeddings are computed once up front so the text encoder can be
        # freed below; neither embedding needs gradients.
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

        # Training only needs the UNet from here on — drop VAE/text stack to
        # reclaim GPU memory before the loop.
        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer

        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")

        if seed == -1:
            # BUG FIX: the module does `from random import random`, so the name
            # `random` is a function and `random.randint(...)` raised
            # AttributeError. Draw the random seed from torch instead, which is
            # already in scope.
            seed = int(torch.randint(0, 2 ** 30, (1,)).item())
        set_seed(int(seed))

        for i in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)
                optimizer.zero_grad()

                # Random intermediate step in [1, nsteps-2] to roll out to.
                iteration = torch.randint(1, nsteps - 1, (1,)).item()
                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)

                # Roll out with the *finetuned* weights to get the latents at
                # the sampled step (no grads needed for the rollout itself).
                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False,
                        use_amp=use_amp
                    )

                # Map the DDIM step index onto the 1000-step training schedule.
                diffuser.set_scheduler_timesteps(1000)
                iteration = int(iteration / nsteps * 1000)

                # Frozen-model predictions (finetuner context NOT active here):
                # these form the fixed regression target.
                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # Finetuned-model prediction, with gradients.
            with finetuner:
                with autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # ESD target: neutral prediction minus the scaled concept direction.
            # loss = criteria(e_n, e_0) works the best try 5000 epochs
            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
            # NOTE(review): wrapper.step presumably runs backward() + optimizer
            # step (with AMP scaling when enabled) — confirm in memory_efficiency.
            memory_efficiency_wrapper.step(optimizer, loss)
            optimizer.zero_grad()

    torch.save(finetuner.state_dict(), save_path)

    # Drop large references explicitly so CUDA memory can actually be reclaimed.
    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
if __name__ == '__main__':

    import argparse

    # CLI: each option maps 1:1 onto a keyword parameter of train().
    cli = argparse.ArgumentParser()

    cli.add_argument("--repo_id_or_path", required=True)
    cli.add_argument("--img_size", type=int, default=512)
    cli.add_argument('--prompt', required=True)
    cli.add_argument('--modules', required=True)
    cli.add_argument('--freeze_modules', nargs='+', required=True)
    cli.add_argument('--save_path', required=True)
    cli.add_argument('--iterations', type=int, required=True)
    cli.add_argument('--lr', type=float, required=True)
    cli.add_argument('--negative_guidance', type=float, required=True)
    cli.add_argument('--seed', type=int, default=-1,
                     help='Training seed for reproducible results, or -1 to pick a random seed')

    # Boolean feature toggles, all defaulting to off.
    for flag in ('--use_adamw8bit', '--use_xformers', '--use_amp', '--use_gradient_checkpointing'):
        cli.add_argument(flag, action='store_true')

    train(**vars(cli.parse_args()))