from diffusers import DiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from datasets import load_dataset
from PIL import Image
import torch
import torchvision.transforms as T
import numpy as np
import pandas as pd
from tqdm.auto import tqdm


class RCTDiffusionPipeline(DiffusionPipeline):
    """Latent diffusion pipeline that generates RollerCoaster Tycoon style object images
    from an object description and up to three colour labels.

    It wraps a conditional UNet, a noise scheduler, a VAE and a CLIP-style text
    tokenizer/encoder, and builds its label dictionaries from the 'frutiemax/rct_dataset'
    dataset on the Hugging Face Hub.
    """

    def __init__(self, unet, scheduler, vae, text_tokenizer, text_encoder, vae_image_processor: VaeImageProcessor,
                 latent_size=32, sample_size=256):
        super().__init__()

        # Label-to-index dictionaries, filled from the dataset below.
        self.object_description_dict = {}
        self.color1_dict = {}
        self.color2_dict = {}
        self.color3_dict = {}

        self.scheduler = scheduler
        self.unet = unet
        self.vae = vae
        self.latent_size = latent_size
        self.sample_size = sample_size
        self.text_encoder = text_encoder
        self.text_tokenizer = text_tokenizer
        self.vae_image_processor = vae_image_processor

        self.num_channels = int(self.unet.config.in_channels)
        self.load_dictionaries_from_dataset()
        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, text_tokenizer=text_tokenizer, text_encoder=text_encoder)

    def load_dictionaries_from_dataset(self):
        """Build the label-to-index dictionaries from the dataset's training split.

        Colour values equal to 'none' are skipped so they map to an all-zero weight vector.
        """
        dataset = load_dataset('frutiemax/rct_dataset')
        dataset = dataset['train']

        for row in dataset:
            if row['object_description'] not in self.object_description_dict:
                self.object_description_dict[row['object_description']] = len(self.object_description_dict)
            if row['color1'] not in self.color1_dict and row['color1'] != 'none':
                self.color1_dict[row['color1']] = len(self.color1_dict)
            if row['color2'] not in self.color2_dict and row['color2'] != 'none':
                self.color2_dict[row['color2']] = len(self.color2_dict)
            if row['color3'] not in self.color3_dict and row['color3'] != 'none':
                self.color3_dict[row['color3']] = len(self.color3_dict)

    def print_class_tokens_to_csv(self):
        """Dump the label dictionaries to CSV files for inspection."""
        object_descriptions = pd.DataFrame(self.object_description_dict.items())
        object_descriptions.to_csv('object_descriptions_tokens.csv')

        color1 = pd.DataFrame(self.color1_dict.items())
        color1.to_csv('color1_tokens.csv')

        color2 = pd.DataFrame(self.color2_dict.items())
        color2.to_csv('color2_tokens.csv')

        color3 = pd.DataFrame(self.color3_dict.items())
        color3.to_csv('color3_tokens.csv')

    def _get_weights(self, classifiers: list[tuple[str, float]], token_dict: dict) -> np.ndarray:
        # Turn a list of (label, weight) pairs into a dense weight vector over the dictionary.
        result = np.zeros(len(token_dict))
        for label, weight in classifiers:
            if label in token_dict:
                result[token_dict[label]] = weight
        return result

    def get_object_description_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        return self._get_weights(classifiers, self.object_description_dict)

    def get_color1_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        return self._get_weights(classifiers, self.color1_dict)

    def get_color2_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        return self._get_weights(classifiers, self.color2_dict)

    def get_color3_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        return self._get_weights(classifiers, self.color3_dict)

    def get_class_labels_size(self):
        return len(self.object_description_dict) + len(self.color1_dict) + len(self.color2_dict) + len(self.color3_dict)

    def pack_labels_to_tensor(self, num_images, object_descriptions: np.ndarray, colors1: np.ndarray, colors2: np.ndarray, colors3: np.ndarray) -> torch.Tensor:
        """Concatenate the per-image weight vectors into a (num_images, 1, num_labels) tensor."""
        num_labels = self.get_class_labels_size()
        class_labels = torch.zeros(size=(num_images, num_labels))

        for batch_index in range(num_images):
            offset = 0
            class_labels[batch_index, offset:offset + len(self.object_description_dict)] = torch.from_numpy(object_descriptions[batch_index])

            offset += len(self.object_description_dict)
            class_labels[batch_index, offset:offset + len(self.color1_dict)] = torch.from_numpy(colors1[batch_index])

            offset += len(self.color1_dict)
            class_labels[batch_index, offset:offset + len(self.color2_dict)] = torch.from_numpy(colors2[batch_index])

            offset += len(self.color2_dict)
            class_labels[batch_index, offset:offset + len(self.color3_dict)] = torch.from_numpy(colors3[batch_index])

        class_labels = torch.reshape(class_labels, (num_images, 1, num_labels))
        return class_labels

    def get_class_labels(self, object_description: list[list[tuple[str, float]]], color1: list[list[tuple[str, float]]],
                         color2: list[list[tuple[str, float]]] = None, color3: list[list[tuple[str, float]]] = None,
                         batch_size=1):
        """Build the packed class-label tensor for a batch of weighted (label, weight) lists.

        Returns None if any provided list does not have exactly batch_size entries.
        """
        if len(object_description) != batch_size:
            return None

        if len(color1) != batch_size:
            return None

        if color2 is not None and len(color2) != batch_size:
            return None

        if color3 is not None and len(color3) != batch_size:
            return None

        object_descriptions = []
        colors1 = []
        colors2 = []
        colors3 = []

        for batch_index in range(batch_size):
            obj_desc = self.get_object_description_weights(object_description[batch_index])
            c1 = self.get_color1_weights(color1[batch_index])

            # Missing colour lists fall back to all-zero weight vectors.
            if color2 is not None:
                c2 = self.get_color2_weights(color2[batch_index])
            else:
                c2 = self.get_color2_weights([])

            if color3 is not None:
                c3 = self.get_color3_weights(color3[batch_index])
            else:
                c3 = self.get_color3_weights([])

            object_descriptions.append(obj_desc)
            colors1.append(c1)
            colors2.append(c2)
            colors3.append(c3)

        return self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda', dtype=torch.float16)

    def generate_noise_batches(self, batch_size):
        # Draw the initial latents for the whole batch in one call; torch.seed() picks a
        # fresh random seed, which is also applied to the CUDA generator.
        seed = torch.seed()
        torch.cuda.manual_seed(seed)
        return torch.randn((batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')

    def test_generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
        """Encode one comma-separated prompt per image with the CLIP text encoder.

        The (77, 768) shape matches the tokenizer's max sequence length and the encoder's hidden size.
        """
        batch_size = len(object_description)

        embeddings = torch.Tensor(size=(batch_size, 77, 768))
        for batch_index in range(batch_size):
            prompt = f'{object_description[batch_index]},{color1[batch_index]},{color2[batch_index]}, {color3[batch_index]}'
            tokens = self.text_tokenizer(prompt, padding="max_length", max_length=self.text_tokenizer.model_max_length,
                                         truncation=True, return_tensors="pt")
            with torch.no_grad():
                embeddings[batch_index] = self.text_encoder(tokens.input_ids.to('cuda'))[0]

        return embeddings

    def generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
        """Encode the four label fields separately and concatenate them along the hidden
        dimension, giving a (77, 768 * 4) embedding per image.
        """
        batch_size = len(object_description)

        embeddings = torch.Tensor(size=(batch_size, 77, 768 * 4))
        for batch_index in range(batch_size):
            object_description_tokens = self.text_tokenizer(object_description[batch_index], padding="max_length",
                                                            max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
            color1_tokens = self.text_tokenizer(color1[batch_index], padding="max_length",
                                                max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
            color2_tokens = self.text_tokenizer(color2[batch_index], padding="max_length",
                                                max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
            color3_tokens = self.text_tokenizer(color3[batch_index], padding="max_length",
                                                max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
            with torch.no_grad():
                object_description_embeddings = self.text_encoder(object_description_tokens.input_ids.to('cuda'))[0]
                color1_embeddings = self.text_encoder(color1_tokens.input_ids.to('cuda'))[0]
                color2_embeddings = self.text_encoder(color2_tokens.input_ids.to('cuda'))[0]
                color3_embeddings = self.text_encoder(color3_tokens.input_ids.to('cuda'))[0]

            emb = torch.cat([object_description_embeddings, color1_embeddings, color2_embeddings, color3_embeddings], dim=2)
            embeddings[batch_index] = emb

        return embeddings.to(dtype=torch.float16)

    def validate_inputs(self, object_description: list[str], color1: list[str],
                        color2: list[str], color3: list[str], batch_size) -> tuple[bool, list[str], list[str], list[str], list[str]]:
        """Check that every label list matches batch_size; fill missing colour lists with 'none'.

        Always returns the full 5-tuple so the caller can unpack it even when validation fails.
        """
        if len(object_description) != batch_size:
            return False, object_description, color1, color2, color3

        if len(color1) != batch_size:
            return False, object_description, color1, color2, color3

        if color2 is None:
            color2 = ['none'] * batch_size
        elif len(color2) != batch_size:
            return False, object_description, color1, color2, color3

        if color3 is None:
            color3 = ['none'] * batch_size
        elif len(color3) != batch_size:
            return False, object_description, color1, color2, color3

        return True, object_description, color1, color2, color3

    def __call__(self, object_description: list[str], color1: list[str],
                 color2: list[str] = None, color3: list[str] = None,
                 batch_size=1, num_inference_steps=100, generator=None):
        """Run the denoising loop and return a list of PIL images, or None if the inputs are invalid.

        The pipeline runs on CUDA in float16. `generator` is accepted for API compatibility, but
        the initial noise is currently drawn with a fresh random seed.
        """
        self.unet.to(device='cuda', dtype=torch.float16)
        self.vae.to(device='cuda', dtype=torch.float16)
        self.text_encoder.to(device='cuda', dtype=torch.float16)

        res, object_description, color1, color2, color3 = self.validate_inputs(object_description, color1, color2, color3, batch_size)
        if not res:
            return None
        embeddings = self.test_generate_embeddings(object_description, color1, color2, color3)
        embeddings = embeddings.to(device='cuda', dtype=torch.float16)

        self.scheduler.set_timesteps(num_inference_steps)
        noise_batches = self.generate_noise_batches(batch_size).to(dtype=torch.float16)

        progress_bar = tqdm(total=num_inference_steps)
        epoch = 0
        test_image = None
        for t in self.scheduler.timesteps:
            progress_bar.set_description(f'Inference step {epoch}')

            # The UNet already processes the whole batch at once, so no per-image loop is needed.
            noise_batch = self.scheduler.scale_model_input(noise_batches, timestep=t)
            with torch.no_grad():
                noise_residual = self.unet(noise_batch, t, encoder_hidden_states=embeddings).sample
                noise_batches = self.scheduler.step(noise_residual, t, noise_batch).prev_sample

            # Debug preview of the first latent at every step (opens an image viewer window).
            test_image = self.decode_latent(noise_batches[0], self.vae.config.scaling_factor)

            progress_bar.update(1)
            epoch += 1
            test_image.show()

        # Decode the final latents with the VAE and map the result back to [0, 1] images.
        noise_batches = noise_batches.to('cuda')
        with torch.no_grad():
            result = self.vae.decode(noise_batches / self.vae.config.scaling_factor).sample
            images = self.vae_image_processor.denormalize(result)

        # Convert every image tensor in the batch to a PIL image.
        tensor_to_pil = T.ToPILImage()
        return [tensor_to_pil(images[batch_index]) for batch_index in range(batch_size)]

    def decode_latent(self, image, vae_scaling_factor) -> Image.Image:
        # Preview helper: rescales a latent from [-1, 1] to [0, 1] and converts it to a PIL
        # image without running the VAE (vae_scaling_factor is currently unused).
        tensor_to_pil = T.ToPILImage()
        image = (image / 2 + 0.5).clamp(0, 1)
        image = tensor_to_pil(image)
        return image
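

# Example usage: a minimal sketch of wiring the pipeline together, not part of the class itself.
# The checkpoint paths and the CLIP model name ('openai/clip-vit-large-patch14') are assumptions;
# substitute your own trained UNet/VAE checkpoints. A CUDA GPU is required because the pipeline
# hard-codes float16 on CUDA.
if __name__ == '__main__':
    from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    unet = UNet2DConditionModel.from_pretrained('path/to/unet')  # hypothetical checkpoint path
    vae = AutoencoderKL.from_pretrained('path/to/vae')           # hypothetical checkpoint path
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
    image_processor = VaeImageProcessor()

    pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder, image_processor)

    # Illustrative labels; use description/colour values from 'frutiemax/rct_dataset'.
    images = pipeline(['small tree'], ['dark green'], batch_size=1, num_inference_steps=50)
    if images is not None:
        images[0].save('output.png')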