from datasets import load_dataset
import PIL.Image
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torch.nn as nn
import torchvision.transforms as T
from tqdm.auto import tqdm
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPTextModel, CLIPTokenizer

SAMPLE_SIZE = 256
LATENT_SIZE = 32
SAMPLE_NUM_CHANNELS = 3
LATENT_NUM_CHANNELS = 4

def save_and_test(pipeline, epoch):
    # run a fixed prompt through the pipeline and save the generated sprites
    outputs = pipeline(['aleppo pine tree'], ['dark green'])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)

    model_file = f'rct_foliage_{epoch}'
    pipeline.save_pretrained(model_file)

def transform_images(image):
    # scale the sprite so it fits inside SAMPLE_SIZE x SAMPLE_SIZE, keeping the aspect ratio
    pil_to_tensor = T.PILToTensor()
    scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
    image = image.resize(
        size=(int(scale_factor * image.width), int(scale_factor * image.height)),
        resample=Resampling.NEAREST)

    # paste the resized sprite centered on a black canvas
    new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
    new_image.paste(image, box=(int((SAMPLE_SIZE - image.width) / 2),
                                int((SAMPLE_SIZE - image.height) / 2)))
    return pil_to_tensor(new_image)

def convert_images(dataset):
    # on-the-fly dataset transform: letterbox every image and pass the labels through
    images = [transform_images(image) for image in dataset['image']]
    object_descriptions = [obj_desc for obj_desc in dataset['object_description']]
    colors1 = [color1 for color1 in dataset['color1']]
    colors2 = [color2 for color2 in dataset['color2']]
    colors3 = [color3 for color3 in dataset['color3']]
    return {'image': images, 'object_description': object_descriptions,
            'color1': colors1, 'color2': colors2, 'color3': colors3}

def convert_labels(dataset, model, num_images):
    # get the labels from the view 0 entries only
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]
    del view0_entries

    # convert the descriptions, color1, color2 and color3 to lists of (label, weight=1.0) tuples
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

    # convert those tuples to numpy arrays using the helper functions of the model
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    # finally, pack those numpy arrays into a single tensor
    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset
    return class_labels.to(dtype=torch.float16, device='cuda')

def create_embeddings(dataset, model):
    # build text embeddings from the object description and the three color labels
    object_descriptions = dataset['object_description']
    colors1 = dataset['color1']
    colors2 = dataset['color2']
    colors3 = dataset['color3']
    return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
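# Illustrative sketch (not called by the training flow): how transform_images
# letterboxes a sprite. A hypothetical 128x64 input is scaled by
# min(256/128, 256/64) = 2.0 to 256x128, then pasted centered on a black
# 256x256 canvas:
#
#   sprite = PIL.Image.new('RGB', (128, 64))
#   tensor = transform_images(sprite)  # uint8 tensor of shape (3, 256, 256)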
def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100,
                save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
    vae_image_processor = VaeImageProcessor()

    # total_images == -1 means train on the full split; a literal train[0:-1]
    # slice would silently drop the last image
    split = 'train' if total_images == -1 else f'train[0:{total_images}]'
    dataset = load_dataset('frutiemax/rct_dataset', split=split)
    dataset.set_transform(convert_images)
    num_images = dataset.num_rows

    unet = UNet2DConditionModel(
        sample_size=LATENT_SIZE,
        in_channels=LATENT_NUM_CHANNELS,
        out_channels=LATENT_NUM_CHANNELS,
        down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D',
                          'CrossAttnDownBlock2D', 'DownBlock2D'),
        up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D',
                        'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
        cross_attention_dim=768,
        block_out_channels=(320, 640, 1280, 1280),
        norm_num_groups=32)
    unet = unet.to(dtype=torch.float32)

    # keep batch-norm layers in float32 even when the rest of the model is cast down
    # https://discuss.pytorch.org/t/training-with-half-precision/11815
    for layer in unet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()

    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
    tokenizer = CLIPTokenizer.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(
        'CompVis/stable-diffusion-v1-4', subfolder='text_encoder', use_safetensors=True).to('cuda')
    vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', use_safetensors=True)
    vae = vae.to(dtype=torch.float32, device='cuda')

    # only the U-Net is trained; freeze the VAE and the text encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999, verbose=True)
    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder, vae_image_processor)
    unet = unet.to('cuda')

    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # train for the requested number of epochs, with a random noise level per sprite
    progress_bar = tqdm(total=epochs)
    loss_fn = torch.nn.MSELoss()

    for epoch in range(epochs):
        # create a noisy version of each sprite
        for step, batch in enumerate(train_dataloader):
            clean_images = batch['image']
            batch_size = clean_images.size(0)
            embeddings = create_embeddings(batch, model)
            clean_images = torch.reshape(
                clean_images,
                (batch_size, SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).to(device='cuda')

            # use the vae to move the sprites into latent space
            clean_images = vae_image_processor.preprocess(clean_images)
            latent_images = vae.encode(clean_images).latent_dist.sample()
            latent_images = latent_images * vae.config.scaling_factor

            # add noise at a random timestep per image
            noise = torch.randn_like(latent_images)
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,)).to(device='cuda')
            noisy_images = scheduler.add_noise(latent_images, noise, timesteps)

            batch_embeddings = embeddings.to('cuda')

            # predict the noise and step the optimizer
            optimizer.zero_grad()
            unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
            loss = loss_fn(unet_results, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.set_description(f'epoch={epoch}, batch_index={step}, last_loss={loss.item()}')

        if (epoch + 1) % save_model_interval == 0:
            # inference in float16
            model = RCTDiffusionPipeline(
                unet.to(dtype=torch.float16), scheduler, vae.to(dtype=torch.float16),
                tokenizer, text_encoder.to(dtype=torch.float16), vae_image_processor)
            save_and_test(model, epoch)

            # training in float32
            unet.to(dtype=torch.float32)
            vae.to(dtype=torch.float32)
            text_encoder.to(dtype=torch.float32)
        progress_bar.update(1)
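# Illustrative sketch of the latent round-trip the loop above relies on
# (assumes the same vae as train_model): for stabilityai/sd-vae-ft-mse,
# vae.config.scaling_factor is 0.18215, so a (3, 256, 256) image maps to a
# (4, 32, 32) latent; the pipeline must divide by the same factor on decode:
#
#   latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor
#   decoded = vae.decode(latents / vae.config.scaling_factor).sample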
if __name__ == '__main__':
    train_model(batch_size=48, save_model_interval=25, epochs=1000, start_learning_rate=1e-3)
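# Illustrative sketch (assumes RCTDiffusionPipeline follows the standard
# diffusers save_pretrained/from_pretrained contract used in save_and_test;
# 'rct_foliage_24' is the first checkpoint written with save_model_interval=25):
#
#   pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_24')
#   images = pipeline(['aleppo pine tree'], ['dark green'])
#   images[0].save('sample.png')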