from datasets import load_dataset
import PIL.Image
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torchvision
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from accelerate import Accelerator
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn as nn

SAMPLE_SIZE = 256
LATENT_SIZE = 32
SAMPLE_NUM_CHANNELS = 3
LATENT_NUM_CHANNELS = 4

def save_and_test(pipeline, epoch):
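    """Generate sample images from a fixed test prompt, save them to disk,
    and save the pipeline weights under 'rct_foliage_{epoch}'."""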
    outputs = pipeline(['aleppo pine tree'], ['dark green'])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)
    
    model_file = f'rct_foliage_{epoch}'
    pipeline.save_pretrained(model_file)

def transform_images(image):
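    """Resize an image to fit within SAMPLE_SIZE x SAMPLE_SIZE while keeping its
    aspect ratio, paste it centered on a square RGB canvas and return it as a tensor."""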
    pil_to_tensor = T.PILToTensor()

    # scale so the sprite fits inside SAMPLE_SIZE x SAMPLE_SIZE while keeping its aspect ratio
    scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
    image = image.resize(size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)

    # paste it centered on a square black canvas
    new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
    new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))

    # PILToTensor returns uint8 in [0, 255]; scale to float32 in [-1, 1] as expected by the Stable Diffusion VAE
    res = pil_to_tensor(new_image).to(dtype=torch.float32) / 127.5 - 1.0
    return res
        
def convert_images(dataset):
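    """On-the-fly dataset transform: convert the PIL images to tensors and pass
    the label columns through unchanged."""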
    images = [transform_images(image) for image in dataset["image"]]
    object_descriptions = [obj_desc for obj_desc in dataset["object_description"]]
    colors1 = [color1 for color1 in dataset['color1']]
    colors2 = [color2 for color2 in dataset['color2']]
    colors3 = [color3 for color3 in dataset['color3']]

    return {"image": images, 'object_description':object_descriptions, 'color1':colors1, \
            'color2':colors2, 'color3':colors3}

def convert_labels(dataset, model, num_images):
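    """Build the class-label tensor from the view-0 entries of the dataset
    (not called by train_model in this script, which uses create_embeddings instead)."""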
    # get the labels
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]

    del view0_entries

    # convert the descriptions, color1, color2 and color3 into lists of (label, weight=1.0) tuples
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

    # convert those tuples into numpy arrays using the model's helper functions
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    # finally, convert those numpy arrays to a tensor
    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset
    return class_labels.to(dtype=torch.float16, device='cuda')

def create_embeddings(dataset, model):
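    """Compute the conditioning embeddings for a batch from its object description
    and colour labels, using the pipeline's test_generate_embeddings helper."""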
    object_descriptions = dataset['object_description']
    colors1 = dataset['color1']
    colors2 = dataset['color2']
    colors3 = dataset['color3']
    return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)


def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100, save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
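    """Train the UNet of an RCTDiffusionPipeline on the frutiemax/rct_dataset sprites.

    The tokenizer, text encoder and VAE are loaded from pretrained Stable Diffusion
    checkpoints; only the UNet parameters are passed to the optimizer. Every
    save_model_interval epochs the pipeline is cast to float16 for a test generation
    and saved, then cast back to float32 to resume training.
    """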
    dataset = load_dataset('frutiemax/rct_dataset', split=f'train[0:{total_images}]')
    dataset.set_transform(convert_images)
    num_images = dataset.num_rows

    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
                        down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
                              up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
                            block_out_channels=(128, 256, 512, 512), norm_num_groups=32)
    unet = unet.to(dtype=torch.float32)

    #https://discuss.pytorch.org/t/training-with-half-precision/11815
    for layer in unet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()

    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
    ).to('cuda')
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
    vae = vae.to(dtype=torch.float32, device='cuda')

    optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_images * epochs
    )
    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
    unet = unet.to('cuda')
    
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # train for `epochs` epochs over the dataset, adding a random noise level to each sprite
    progress_bar = tqdm(total=epochs)

    loss_fn = torch.nn.MSELoss()

    tensor_to_pillow = T.ToPILImage()
    for epoch in range(epochs):
        # create a noisy version of each sprite
        for step, batch in enumerate(train_dataloader):
            clean_images = batch['image']
            batch_size = clean_images.size(0)
            embeddings = create_embeddings(batch, model)
            clean_images = torch.reshape(clean_images, (batch['image'].size(0), SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).\
                to(device='cuda')

            noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size, )).to(device='cuda')

            #timesteps = timesteps.to(dtype=torch.int, device='cuda')
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
            
            batch_embeddings = embeddings.to('cuda')

            # use the vae to get the latent images
            latent_images = vae.encode(noisy_images).latent_dist.sample()

            optimizer.zero_grad()
            unet_results = unet(latent_images, timesteps, batch_embeddings).sample

            # decode the predicted latents back to pixel space so the loss is computed against the pixel-space noise
            noise_pred = vae.decode(unet_results).sample

            loss = loss_fn(noise_pred, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.set_description(f'epoch={epoch}, batch_index={step}, last_loss={loss.item()}')
        
        if (epoch + 1) % save_model_interval == 0:
            # inference in float16
            model = RCTDiffusionPipeline(unet.to(dtype=torch.float16), scheduler, \
                                         vae.to(dtype=torch.float16), tokenizer, text_encoder.to(dtype=torch.float16))
            save_and_test(model, epoch)

            # training in float32
            unet.to(dtype=torch.float32)
            vae.to(dtype=torch.float32)
            text_encoder.to(dtype=torch.float32)

        progress_bar.update(1)


if __name__ == '__main__':
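    # batch size 1, 100 epochs, save and test the pipeline every 10 epochs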
    train_model(1, save_model_interval=10, epochs=100)