"""Train the RCT diffusion UNet on sprites from ``frutiemax/rct_dataset``.

Rows are grouped by their ``view`` column (0..3).  The i-th row of every view
is assumed to show the same object, and the four RGB views of an object are
stacked into one 12-channel training target.
"""
import math

from datasets import load_dataset
from PIL.Image import Image
import PIL
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from accelerate import Accelerator
from diffusers import DDPMScheduler, UNet2DConditionModel

_NUM_VIEWS = 4  # rows are grouped by 'view' values 0 .. _NUM_VIEWS - 1
_IMAGE_SIZE = 256  # side length of the square canvas every sprite is placed on
_LABEL_KEYS = ('object_description', 'color1', 'color2', 'color3')


def save_and_test(pipeline, epoch):
    """Sample one image set from *pipeline* and save it together with the weights.

    Writes ``out{index}_{epoch}.png`` for every image the pipeline returns, then
    calls ``pipeline.save_pretrained('rct_foliage_{epoch}.pth')``.
    """
    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
    for image_index, output in enumerate(outputs):
        output.save(f'out{image_index}_{epoch}.png')
    # NOTE(review): if RCTDiffusionPipeline follows diffusers' DiffusionPipeline,
    # save_pretrained expects a *directory*; the '.pth' suffix is kept only so
    # existing output names do not change.
    pipeline.save_pretrained(f'rct_foliage_{epoch}.pth')


def _split_by_view(dataset):
    """Group dataset rows by their 'view' value in a single pass, preserving order.

    Returns a list of ``_NUM_VIEWS`` lists; rows whose view is not one of
    0 .. _NUM_VIEWS - 1 are ignored.
    """
    buckets = {view_index: [] for view_index in range(_NUM_VIEWS)}
    for entry in dataset:
        bucket = buckets.get(entry['view'])
        if bucket is not None:
            bucket.append(entry)
    return [buckets[view_index] for view_index in range(_NUM_VIEWS)]


def _fit_to_canvas(image, size=_IMAGE_SIZE):
    """Upscale *image* by the largest integer factor that fits and centre it on a size x size RGB canvas.

    Nearest-neighbour resampling keeps sprite pixels crisp.  Sprites already
    larger than the canvas keep scale 1 and are centre-cropped by the paste
    (a scale of 0 would make ``resize`` fail).
    """
    scale_factor = max(1, int(min(size / image.width, size / image.height)))
    image = image.resize(
        (scale_factor * image.width, scale_factor * image.height),
        resample=Resampling.NEAREST,
    )
    canvas = PIL.Image.new('RGB', (size, size))
    canvas.paste(image, box=((size - image.width) // 2, (size - image.height) // 2))
    return canvas


def _build_targets(image_views, num_images):
    """Stack the per-view canvases into a ``(num_images, 3 * _NUM_VIEWS, H, W)`` float16 tensor.

    Channels are view-major: view 0's RGB, then view 1's RGB, and so on.
    ``ToTensor`` yields pixel values in [0, 1].
    """
    pillow_to_tensor = T.ToTensor()
    # Allocate directly in float16 instead of allocating float32 and converting.
    targets = torch.empty(
        (num_images, _NUM_VIEWS, 3, _IMAGE_SIZE, _IMAGE_SIZE), dtype=torch.float16
    )
    for image_index in range(num_images):
        for view_index, view_images in enumerate(image_views):
            targets[image_index, view_index] = pillow_to_tensor(view_images[image_index])
    return targets.reshape(num_images, _NUM_VIEWS * 3, _IMAGE_SIZE, _IMAGE_SIZE)


def _build_class_labels(model, label_rows, num_images):
    """Encode each object's description and three colours (weight 1.0) into the model's class-label tensor."""
    obj_descriptions = [
        model.get_object_description_weights([(row['object_description'], 1.0)])
        for row in label_rows
    ]
    colors1 = [model.get_color1_weights([(row['color1'], 1.0)]) for row in label_rows]
    colors2 = [model.get_color2_weights([(row['color2'], 1.0)]) for row in label_rows]
    colors3 = [model.get_color3_weights([(row['color3'], 1.0)]) for row in label_rows]
    return model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)


def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10,
                start_learning_rate=1e-3, lr_warmup_steps=500):
    """Train the default RCT UNet as a DDPM noise predictor on the whole dataset.

    Args:
        batch_size: objects per optimizer step; must be >= 1.
        epochs: number of passes over the dataset.
        scheduler_num_timesteps: number of DDPM timesteps used for noising.
        save_model_interval: sample and save the pipeline every this many epochs.
        start_learning_rate: peak learning rate reached after the linear warmup,
            followed by cosine decay over the whole run.
        lr_warmup_steps: number of optimizer steps (batches) of linear warmup.

    Raises:
        ValueError: if ``batch_size`` < 1, or a view has fewer rows than
            ``num_rows // 4`` objects.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    dataset = load_dataset('frutiemax/rct_dataset')['train']
    num_images = int(dataset.num_rows / _NUM_VIEWS)

    # One pass over the dataset (each row access decodes its image).
    view_entries = _split_by_view(dataset)
    del dataset
    short_views = [index for index, entries in enumerate(view_entries) if len(entries) < num_images]
    if short_views:
        raise ValueError(f'views {short_views} have fewer than {num_images} rows')

    # Labels are taken from view 0; keep only those columns before dropping the rows.
    label_rows = [{key: row[key] for key in _LABEL_KEYS} for row in view_entries[0]]
    image_views = [[_fit_to_canvas(entry['image']) for entry in entries] for entries in view_entries]
    del view_entries
    targets = _build_targets(image_views, num_images)
    del image_views

    model = RCTDiffusionPipeline()
    class_labels = _build_class_labels(model, label_rows, num_images)
    del label_rows

    unet = RCTDiffusionPipeline.get_default_unet(160)
    # NOTE(review): the model is trained in pure float16.  Adam's default
    # eps=1e-8 is below float16's smallest subnormal, so tiny gradients can
    # yield division by zero (inf/NaN weights).  Consider float32 master
    # weights with torch.autocast + GradScaler; kept as-is because the
    # pipeline's expected UNet dtype is not visible here.
    unet = unet.to(device='cuda', dtype=torch.float16)
    # Build the optimizer after the move so it clearly owns the final parameters.
    optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
    # lr_scheduler.step() runs once per batch, so the schedule must span batches, not images.
    num_batches = math.ceil(num_images / batch_size)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_batches * epochs,
    )

    scheduler = DDPMScheduler(scheduler_num_timesteps)
    scheduler.set_timesteps(scheduler_num_timesteps)
    progress_bar = tqdm(total=epochs)
    for epoch in range(epochs):
        for batch_index in range(0, num_images, batch_size):
            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
            batch_end = min(num_images, batch_index + batch_size)

            # Noise every object in the batch at a random timestep.
            clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
            noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
            timesteps = torch.randint(
                0, scheduler.config.num_train_timesteps, (batch_end - batch_index,)
            ).to(device='cuda')
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

            # NOTE(review): class labels are passed as the UNet's third positional
            # argument; for diffusers' UNet2DConditionModel that slot is
            # encoder_hidden_states — confirm it matches get_default_unet's model.
            batch_labels = class_labels[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
            noise_pred = unet(noisy_images, timesteps, batch_labels, return_dict=False)[0]

            loss = F.mse_loss(noise_pred, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if (epoch + 1) % save_model_interval == 0:
            # Temporarily attach the trained components so the pipeline can sample and save them.
            model.unet = unet
            model.scheduler = scheduler
            save_and_test(model, epoch)
            del model.unet
            del model.scheduler
        progress_bar.update(1)
    progress_bar.close()


if __name__ == '__main__':
    train_model(8, save_model_interval=1)