from datasets import load_dataset
import PIL.Image
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from accelerate import Accelerator

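# quick smoke test: render one fixed prompt with the current weights and
# save the resulting images plus a checkpoint of the pipeline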
def save_and_test(pipeline, epoch):
    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)
    
    model_file = f'rct_foliage_{epoch}.pth'
    pipeline.save_pretrained(model_file)

def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
    dataset = load_dataset('frutiemax/rct_dataset')
    dataset = dataset['train']

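    # the dataset stores one row per view, 4 views (rotations) per object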
    num_images = int(dataset.num_rows / 4)

    # split the entries into four lists, one per view
    views = []

    for view_index in range(4):
        entries = [entry for entry in dataset if entry['view'] == view_index]
        views.append(entries)
    
    # convert the images to 256x256 by integer upscaling and centering them on a blank canvas
    image_views = []
    for view_index in range(4):
        images = []
        for entry in views[view_index]:
            image = entry['image']
            
            # integer upscale factor so the sprite fits into 256x256; assumes the
            # source sprites are no larger than 256x256 (guard against a 0 factor)
            scale_factor = max(1, int(np.minimum(256 / image.width, 256 / image.height)))
            image = image.resize((scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
            
            new_image = PIL.Image.new('RGB', (256, 256))
            new_image.paste(image, box=(int((256 - image.width)/2), int((256 - image.height)/2)))
            images.append(new_image)
        image_views.append(images)
    
    del views
    
    # convert those views to tensors
    targets = torch.empty((num_images, 4, 3, 256, 256), dtype=torch.float16)
    pillow_to_tensor = T.ToTensor()
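    # ToTensor yields CHW float tensors scaled to [0, 1] from the PIL images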

    for image_index in range(num_images):
        for view_index in range(4):
            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
    del image_views
    del entries

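    # stack the 4 views channel-wise: 4 views x 3 RGB channels = 12 channels,
    # so the UNet denoises all four rotations of a sprite jointly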
    targets = torch.reshape(targets, (num_images, 12, 256, 256))

    # get the labels
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]

    del view0_entries

    # wrap each description and color as a [(label, weight)] list with weight=1.0
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

    # convert those tuples to numpy weight arrays using the model's helper functions
    model = RCTDiffusionPipeline()
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    # finally, convert those numpy arrays to a tensor
    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset

    optimizer = torch.optim.Adam(model.unet.parameters(), lr=start_learning_rate)

    # the LR scheduler steps once per batch, so size it in optimizer steps
    # rather than in images
    num_batches_per_epoch = int(np.ceil(num_images / batch_size))
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_batches_per_epoch * epochs
    )

    # train every sprite in the dataset for `epochs` epochs with random noise levels
    progress_bar = tqdm(total=epochs)
    accelerator = Accelerator(
        mixed_precision='fp16',
        gradient_accumulation_steps=1,
        log_with="tensorboard",
        project_dir='logs',
    )

    unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(
        model.unet, model.scheduler, optimizer, lr_scheduler)

    # keep the pipeline object alive: it is reused below to save checkpoints
    scheduler.set_timesteps(scheduler_num_timesteps)
    
    for epoch in range(epochs):
        # create a noisy version of each sprite
        for batch_index in range(0, num_images, batch_size):
            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
            batch_end = min(num_images, batch_index + batch_size)
            clean_images = targets[batch_index:batch_end]

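            # forward diffusion: sample gaussian noise and a random timestep per
            # image, then let the noise scheduler produce the noised inputs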
            noise = torch.randn(clean_images.shape, dtype=torch.float16)
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)

            with accelerator.accumulate(unet):
                noise_pred = unet(noisy_images,
                                  timesteps.to(device='cuda', dtype=torch.float16),
                                  class_labels[batch_index:batch_end].to(device='cuda', dtype=torch.float16),
                                  return_dict=False)[0]

                loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
        
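        # periodically export the trained unet back into the pipeline and save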
        if (epoch + 1) % save_model_interval == 0:
            model.unet = accelerator.unwrap_model(unet)
            model.scheduler = scheduler
            save_and_test(model, epoch)
        progress_bar.update(1)


if __name__ == '__main__':
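    # example invocation: batch size 4, all other hyperparameters at defaults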
    train_model(4)