from diffusers import DiffusionPipeline
from diffusers import DDPMPipeline
from diffusers import DDPMScheduler, UNet2DConditionModel
import torch
import torchvision.transforms as T
from PIL import Image
import PIL.Image
from transformers import AutoTokenizer
from datasets import load_dataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm


class RCTDiffusionPipeline(DiffusionPipeline):
    """Class-conditioned diffusion pipeline producing 4 views per sample.

    Each sample is conditioned on a weighted vector made of four sections
    (object description, color1, color2, color3). The UNet denoises the
    latents of the 4 views stacked along the channel axis; each view is then
    decoded separately by the VAE into an RGB image.
    """

    def __init__(self, unet, scheduler, vae, latent_size=32, sample_size=256):
        super().__init__()

        # Vocabularies mapping each class name to its index inside its own
        # section of the label vector (filled by load_dictionaries_from_dataset).
        self.object_description_dict = {}
        self.color1_dict = {}
        self.color2_dict = {}
        self.color3_dict = {}
        self.scheduler = scheduler
        self.unet = unet
        self.vae = vae
        self.latent_size = latent_size
        self.sample_size = sample_size

        # The UNet input holds the latents of 4 views stacked on the channel
        # axis, so a single view uses a quarter of the input channels.
        self.num_channels = self.unet.config.in_channels // 4

    @staticmethod
    def _register_class(mapping: dict, name) -> None:
        """Give *name* the next free index in *mapping* unless already present."""
        mapping.setdefault(name, len(mapping))

    def load_dictionaries_from_dataset(self):
        """Fill the four vocabularies from the 'frutiemax/rct_dataset' train split.

        Classes are indexed in order of first appearance. The value 'none' is
        skipped for the three color columns (but not for object_description).
        """
        dataset = load_dataset('frutiemax/rct_dataset')['train']
        color_columns = (
            ('color1', self.color1_dict),
            ('color2', self.color2_dict),
            ('color3', self.color3_dict),
        )
        for row in dataset:
            self._register_class(self.object_description_dict, row['object_description'])
            for column, mapping in color_columns:
                if row[column] != 'none':
                    self._register_class(mapping, row[column])

    # helper functions to know the classes
    def print_class_tokens_to_csv(self):
        """Write each vocabulary as (class name, index) rows to a CSV file in the CWD."""
        outputs = (
            (self.object_description_dict, 'object_descriptions_tokens.csv'),
            (self.color1_dict, 'color1_tokens.csv'),
            (self.color2_dict, 'color2_tokens.csv'),
            (self.color3_dict, 'color3_tokens.csv'),
        )
        for mapping, path in outputs:
            pd.DataFrame(list(mapping.items())).to_csv(path)

    # helper functions to build weight tables
    @staticmethod
    def _build_weights(mapping: dict, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Return a float vector of len(mapping) with the given class weights set.

        Each (class_name, weight) pair writes *weight* at that class's index;
        unknown class names are ignored and later pairs overwrite earlier ones.
        """
        result = np.zeros(len(mapping))
        for name, weight in classifiers:
            index = mapping.get(name)
            if index is not None:
                result[index] = weight
        return result

    def get_object_description_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Weight vector over the object-description vocabulary."""
        return self._build_weights(self.object_description_dict, classifiers)

    def get_color1_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Weight vector over the color1 vocabulary."""
        return self._build_weights(self.color1_dict, classifiers)

    def get_color2_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Weight vector over the color2 vocabulary."""
        return self._build_weights(self.color2_dict, classifiers)

    def get_color3_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Weight vector over the color3 vocabulary."""
        return self._build_weights(self.color3_dict, classifiers)

    def get_class_labels_size(self):
        """Total length of the label vector (sum of the four vocabulary sizes)."""
        return (len(self.object_description_dict) + len(self.color1_dict)
                + len(self.color2_dict) + len(self.color3_dict))

    def pack_labels_to_tensor(self, num_images, object_descriptions: np.ndarray, colors1: np.ndarray,
                              colors2: np.ndarray, colors3: np.ndarray) -> torch.Tensor:
        """Concatenate per-image section weights into a (num_images, 1, labels_size) float32 tensor.

        Each argument is indexed by image; entry i must have the length of the
        matching vocabulary. Sections are laid out in the order object
        description, color1, color2, color3.
        """
        sections = (object_descriptions, colors1, colors2, colors3)
        sizes = (len(self.object_description_dict), len(self.color1_dict),
                 len(self.color2_dict), len(self.color3_dict))
        num_labels = sum(sizes)

        class_labels = torch.zeros(num_images, num_labels)
        for batch_index in range(num_images):
            offset = 0
            for section, size in zip(sections, sizes):
                class_labels[batch_index, offset:offset + size] = torch.from_numpy(section[batch_index])
                offset += size
        return torch.reshape(class_labels, (num_images, 1, num_labels))

    def get_class_labels(self, object_description: list[list[tuple[str, float]]], color1: list[list[tuple[str, float]]],
                         color2: list[list[tuple[str, float]]] = None, color3: list[list[tuple[str, float]]] = None,
                         batch_size=1):
        """Build the conditioning tensor for a batch.

        Each argument holds one list of (class_name, weight) pairs per sample.
        color2/color3 may be None, in which case those sections are all zeros.

        Returns:
            A (batch_size, 1, labels_size) float16 tensor on CUDA, or None if
            any provided list's length differs from batch_size.
        """
        if len(object_description) != batch_size or len(color1) != batch_size:
            return None
        if color2 is not None and len(color2) != batch_size:
            return None
        if color3 is not None and len(color3) != batch_size:
            return None

        object_descriptions = [self.get_object_description_weights(pairs) for pairs in object_description]
        colors1 = [self.get_color1_weights(pairs) for pairs in color1]
        colors2 = [self.get_color2_weights(color2[i] if color2 is not None else []) for i in range(batch_size)]
        colors3 = [self.get_color3_weights(color3[i] if color3 is not None else []) for i in range(batch_size)]

        labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3)
        return labels.to(device='cuda', dtype=torch.float16)

    def generate_noise_batches(self, batch_size, generator=None):
        """Sample initial latents of shape (batch_size, 1, 4 * num_channels, latent_size, latent_size).

        Args:
            batch_size: number of samples.
            generator: optional torch.Generator; noise is drawn on the
                generator's device. When None the global CPU RNG is used.

        Returns:
            Standard-normal noise as a float16 CUDA tensor.
        """
        shape = (batch_size, 1, self.num_channels * 4, self.latent_size, self.latent_size)
        device = generator.device if generator is not None else 'cpu'
        noise = torch.randn(shape, generator=generator, device=device)
        return noise.to(dtype=torch.float16, device='cuda')

    # NOTE: the `generator` default is evaluated once, when the class is
    # defined; torch.manual_seed reseeds the global RNG at that moment and
    # returns the default CPU generator.
    def __call__(self, object_description: list[list[tuple[str, float]]], color1: list[list[tuple[str, float]]],
                 color2: list[list[tuple[str, float]]] = None, color3: list[list[tuple[str, float]]] = None,
                 batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
        """Generate 4 PIL images per sample.

        Returns:
            A flat list of batch_size * 4 PIL images (the 4 views of sample 0,
            then sample 1, ...), or None if the label lists do not match
            batch_size.
        """
        # Check for None BEFORE touching the result: get_class_labels returns
        # None on a size mismatch.
        class_labels = self.get_class_labels(object_description, color1, color2, color3, batch_size)
        if class_labels is None:
            return None
        class_labels = class_labels.to(device='cuda', dtype=torch.float16)

        self.scheduler.set_timesteps(num_inference_steps)

        # Drop the singleton axis so the whole batch is denoised in one UNet
        # call, matching the (batch_size, 1, labels) conditioning tensor.
        latents = self.generate_noise_batches(batch_size, generator=generator)
        latents = torch.reshape(latents, (batch_size, self.num_channels * 4, self.latent_size, self.latent_size))

        timesteps = self.scheduler.timesteps
        progress_bar = tqdm(total=len(timesteps))
        with torch.no_grad():
            for step_index, t in enumerate(timesteps):
                progress_bar.set_description(f'Inference step {step_index}')
                noise_residual = self.unet(latents, t, encoder_hidden_states=class_labels).sample
                latents = self.scheduler.step(noise_residual, t, latents).prev_sample
                progress_bar.update(1)
        progress_bar.close()

        # Split the stacked channels back into 4 per-view latents and decode each view.
        latents = torch.reshape(latents, (batch_size, 4, self.num_channels, self.latent_size, self.latent_size))
        images = torch.empty(batch_size, 4, 3, self.sample_size, self.sample_size)
        with torch.no_grad():
            for view_index in range(4):
                images[:, view_index] = self.vae.decode(latents[:, view_index]).sample

        # Convert (3, H, W) tensors to PIL images.
        output_images = []
        for batch_index in range(batch_size):
            for view_index in range(4):
                image = images[batch_index, view_index]
                # Assumes the decoder output lies in [-1, 1]; map to [0, 1].
                image = (image / 2 + 0.5).clamp(0, 1)
                # Scale to [0, 255] exactly once, then quantize.
                array = (image.permute(1, 2, 0).cpu().numpy() * 255).round().astype('uint8')
                output_images.append(Image.fromarray(array))
        return output_images