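"""Training script for the RCTDiffusionPipeline UNet.

Loads the sprites from the frutiemax/rct_dataset dataset, packs the four
views of each object into a single 12-channel target tensor, and trains the
pipeline's UNet to predict the noise added by the scheduler (a DDPM-style
objective), using fp16 mixed precision through accelerate.
"""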
from datasets import load_dataset
from PIL import Image
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from accelerate import Accelerator

def save_and_test(pipeline, epoch):
    # generate a test sprite ('aleppo pine tree' with colour 1 'dark green') and save each output
    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)
    model_file = f'rct_foliage_{epoch}.pth'
    pipeline.save_pretrained(model_file)

def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
    dataset = load_dataset('frutiemax/rct_dataset')
    dataset = dataset['train']

    # the dataset holds 4 rows (one per view) for each object
    num_images = dataset.num_rows // 4

    # split the entries for the 4 views into four lists
    views = []
    for view_index in range(4):
        entries = [entry for entry in dataset if entry['view'] == view_index]
        views.append(entries)

    # convert those images to 256x256 by upscaling with nearest-neighbour
    # (keeps pixel-art edges crisp) and centring them on a black canvas
    image_views = []
    for view_index in range(4):
        images = []
        for entry in views[view_index]:
            image = entry['image']
            scale_factor = int(np.minimum(256 / image.width, 256 / image.height))
            image = image.resize((scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
            new_image = Image.new('RGB', (256, 256))
            new_image.paste(image, box=((256 - image.width) // 2, (256 - image.height) // 2))
            images.append(new_image)
        image_views.append(images)
    del views

    # convert those views into a (num_images, 4, 3, 256, 256) fp16 tensor
    targets = torch.empty((num_images, 4, 3, 256, 256), dtype=torch.float16)
    pillow_to_tensor = T.ToTensor()
    for image_index in range(num_images):
        for view_index in range(4):
            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
    del image_views
    del entries
    targets = torch.reshape(targets, (num_images, 12, 256, 256))
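    # the reshape stacks the 4 RGB views along the channel axis, so the UNet
    # denoises all four views of a sprite jointly; this assumes the pipeline's
    # UNet is configured for 12 input and output channels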

    # get the labels from the view-0 rows (one per object)
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]
    del view0_entries

    # convert the descriptions, color1, color2 and color3 to lists of (label, weight=1.0) tuples
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

    # convert those tuples to numpy weight arrays using the pipeline's helper functions
    model = RCTDiffusionPipeline()
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    # finally, pack those numpy arrays into a single label tensor
    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset
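    # class_labels is assumed to be a float tensor with num_images rows; each
    # batch slice of it is passed to the UNet as conditioning in the loop below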

    optimizer = torch.optim.Adam(model.unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_images * epochs
    )
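    # note: lr_scheduler.step() runs once per *batch* below, so with
    # num_training_steps=num_images * epochs the cosine schedule only advances
    # through 1/batch_size of its decay; (num_images // batch_size) * epochs
    # would match the actual number of optimizer steps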

    # train on each sprite in the dataset with a random noise level, for `epochs` epochs
    progress_bar = tqdm(total=epochs)
    accelerator = Accelerator(
        mixed_precision='fp16',
        gradient_accumulation_steps=1,
        log_with="tensorboard",
        project_dir='logs',
    )
    unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(
        model.unet, model.scheduler, optimizer, lr_scheduler)
    # keep `model` alive: it is reused below to save and test checkpoints
    scheduler.set_timesteps(scheduler_num_timesteps)
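    # set_timesteps configures the scheduler's inference schedule (used when
    # save_and_test runs the pipeline); training itself samples timesteps from
    # the full num_train_timesteps range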

    for epoch in range(epochs):
        for batch_index in range(0, num_images, batch_size):
            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
            batch_end = np.minimum(num_images, batch_index + batch_size)
            clean_images = targets[batch_index:batch_end]
            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))

            # create a noisy version of each sprite at a random timestep
            noise = torch.randn(clean_images.shape, dtype=torch.float16)
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)

            with accelerator.accumulate(unet):
                # the UNet is trained to predict the noise that was added
                noise_pred = unet(noisy_images, timesteps.to(device='cuda', dtype=torch.float16),
                                  class_labels[batch_index:batch_end].to(device='cuda', dtype=torch.float16),
                                  return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        # periodically save a checkpoint and render a test sprite
        if (epoch + 1) % save_model_interval == 0:
            model.unet = accelerator.unwrap_model(unet)
            model.scheduler = scheduler
            save_and_test(model, epoch)
        progress_bar.update(1)

if __name__ == '__main__':
    train_model(4)