|
from datasets import load_dataset |
|
from PIL.Image import Image |
|
import PIL |
|
from PIL.Image import Resampling |
|
import numpy as np |
|
from rct_diffusion_pipeline import RCTDiffusionPipeline |
|
import torch |
|
import torchvision.transforms as T |
|
import torchvision |
|
import torch.nn.functional as F |
|
from diffusers.optimization import get_cosine_schedule_with_warmup |
|
from tqdm.auto import tqdm |
|
from accelerate import Accelerator |
|
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
import torch.nn as nn |
|
|
|
# Pixel-space image geometry: training images are square RGB pictures with
# SAMPLE_SIZE pixels per side.
SAMPLE_SIZE = 256
SAMPLE_NUM_CHANNELS = 3

# Geometry handed to the UNet (its `sample_size` and in/out channel counts).
# NOTE(review): 32 = 256 / 8 and 4 channels presumably mirror the output of
# the Stable Diffusion VAE used for encoding -- confirm if SAMPLE_SIZE changes.
LATENT_SIZE = 32
LATENT_NUM_CHANNELS = 4
|
from torchvision import transforms |
|
|
|
def save_and_test(pipeline, epoch):
    """Run a fixed smoke-test prompt, then store the images and the pipeline.

    Args:
        pipeline: Callable pipeline; it is invoked with a list of object
            descriptions and a list of colors and must return a sequence of
            objects exposing ``save(path)``. It must also provide
            ``save_pretrained(path)``.
        epoch: Value embedded in every output name so that successive
            checkpoints do not overwrite each other.
    """
    generated = pipeline(['aleppo pine tree'], ['dark green'])

    # One PNG per generated image, named out<index>_<epoch>.png.
    for position, picture in enumerate(generated):
        picture.save(f'out{position}_{epoch}.png')

    # The pipeline itself is written to a directory named after the epoch.
    pipeline.save_pretrained(f'rct_foliage_{epoch}')
|
|
|
def transform_images(image):
    """Letterbox a PIL image onto a SAMPLE_SIZE x SAMPLE_SIZE RGB canvas.

    The image is scaled with nearest-neighbour resampling so that its longer
    side becomes exactly SAMPLE_SIZE while the aspect ratio is kept, then it
    is pasted centred on a new square 'RGB' image. The area not covered by
    the picture keeps the fill Pillow uses for a new image.

    Args:
        image: PIL image with non-zero width and height.

    Returns:
        The output of ``torchvision.transforms.PILToTensor`` for the padded
        canvas (per the torchvision documentation a uint8 tensor laid out as
        channels x height x width, i.e. values are NOT normalised here).
    """
    # Exact integer arithmetic instead of `int(float_scale * side)`: the float
    # product can land a hair below the intended integer and truncate to one
    # pixel less. The short side is clamped to 1 so an extreme aspect ratio
    # never asks Pillow for a zero-sized image.
    if image.width >= image.height:
        target_width = SAMPLE_SIZE
        target_height = max(1, image.height * SAMPLE_SIZE // image.width)
    else:
        target_width = max(1, image.width * SAMPLE_SIZE // image.height)
        target_height = SAMPLE_SIZE

    resized = image.resize((target_width, target_height), resample=Resampling.NEAREST)

    # Centre the resized picture; both offsets are >= 0 because neither side
    # exceeds SAMPLE_SIZE.
    canvas = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
    offset = ((SAMPLE_SIZE - target_width) // 2, (SAMPLE_SIZE - target_height) // 2)
    canvas.paste(resized, box=offset)

    return T.PILToTensor()(canvas)
|
|
|
def convert_images(dataset):
    """Batch transform: tensorise the images and copy the label columns.

    Args:
        dataset: Mapping of column name to a sequence of values. It must
            contain 'image', 'object_description', 'color1', 'color2' and
            'color3'.

    Returns:
        A new dict with the same five keys. 'image' holds the result of
        ``transform_images`` for every input image; each remaining column is
        a fresh list with the original values in the original order.
    """
    converted = {"image": [transform_images(picture) for picture in dataset["image"]]}

    # The text columns are passed through untouched, just materialised as lists.
    for column in ('object_description', 'color1', 'color2', 'color3'):
        converted[column] = list(dataset[column])

    return converted
|
|
|
def convert_labels(dataset, model, num_images):
    """Build the packed class-label tensor from the rows whose 'view' is 0.

    Args:
        dataset: Iterable of row mappings with the keys 'view',
            'object_description', 'color1', 'color2' and 'color3'.
        model: Object providing ``get_object_description_weights``,
            ``get_color1_weights``, ``get_color2_weights``,
            ``get_color3_weights`` and ``pack_labels_to_tensor``.
        num_images: Forwarded unchanged as the first argument of
            ``model.pack_labels_to_tensor``.

    Returns:
        The value returned by ``model.pack_labels_to_tensor`` after
        ``.to(dtype=torch.float16, device='cuda')``.
    """
    selected_rows = [entry for entry in dataset if entry['view'] == 0]

    def _as_weighted(column):
        # Each label becomes a one-element list holding (label, weight=1.0).
        return [[(entry[column], 1.0)] for entry in selected_rows]

    description_pairs = _as_weighted('object_description')
    primary_pairs = _as_weighted('color1')
    secondary_pairs = _as_weighted('color2')
    tertiary_pairs = _as_weighted('color3')

    # Convert column by column so the model sees the calls in the same order
    # as before: all descriptions first, then color1, color2 and color3.
    description_weights = [model.get_object_description_weights(pairs) for pairs in description_pairs]
    primary_weights = [model.get_color1_weights(pairs) for pairs in primary_pairs]
    secondary_weights = [model.get_color2_weights(pairs) for pairs in secondary_pairs]
    tertiary_weights = [model.get_color3_weights(pairs) for pairs in tertiary_pairs]

    packed = model.pack_labels_to_tensor(
        num_images,
        description_weights,
        primary_weights,
        secondary_weights,
        tertiary_weights,
    )
    return packed.to(dtype=torch.float16, device='cuda')
|
|
|
def create_embeddings(dataset, model): |
|
object_descriptions = dataset['object_description'] |
|
colors1 = dataset['color1'] |
|
colors2 = dataset['color2'] |
|
colors3 = dataset['color3'] |
|
return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3) |
|
|
|
|
|
def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100, save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
    """Train the conditional UNet on 'frutiemax/rct_dataset' and export checkpoints.

    Per batch: Gaussian noise is mixed into the pixel-space images by the
    DDPM scheduler, the noisy images are encoded by the VAE, the UNet is run
    on the resulting latents together with the text embeddings, the UNet
    output is decoded by the VAE and compared (MSE) with the pixel-space
    noise. Only the UNet parameters are optimised.

    Args:
        batch_size: Number of samples per DataLoader batch.
        total_images: Number of leading samples of the 'train' split to use;
            any negative value selects the whole split.
        epochs: Number of passes over the selected samples.
        scheduler_num_timesteps: ``num_train_timesteps`` of the DDPM scheduler.
        save_model_interval: Every this many epochs the pipeline is rebuilt
            in float16 and handed to ``save_and_test``.
        start_learning_rate: AdamW learning rate (peak of the cosine schedule).
        lr_warmup_steps: Warm-up steps of the cosine schedule, counted in
            optimizer steps (one per batch).
    """
    device = 'cuda'

    # The split string uses Python-like slicing, so the old 'train[0:-1]'
    # produced by the default total_images=-1 silently dropped the last
    # sample. A negative count now means "the whole split".
    split = 'train' if total_images < 0 else f'train[0:{total_images}]'
    dataset = load_dataset('frutiemax/rct_dataset', split=split)
    dataset.set_transform(convert_images)

    # Built before the LR schedule because the schedule length is measured in
    # batches, not in samples.
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    unet = UNet2DConditionModel(
        sample_size=LATENT_SIZE,
        in_channels=LATENT_NUM_CHANNELS,
        out_channels=LATENT_NUM_CHANNELS,
        down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=768,
        block_out_channels=(128, 256, 512, 512),
        norm_num_groups=32,
    )
    unet = unet.to(dtype=torch.float32)

    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
    ).to(device)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
    vae = vae.to(dtype=torch.float32, device=device)

    # The optimizer only holds UNet parameters. Freezing the pretrained
    # modules stops backward() from computing and accumulating gradients on
    # them that nothing ever reads or clears; gradients still flow *through*
    # the VAE decoder to the UNet output.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        # lr_scheduler.step() runs once per batch, so the schedule must span
        # batches-per-epoch * epochs (not samples * epochs).
        num_training_steps=len(train_dataloader) * epochs,
    )
    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
    unet = unet.to(device)

    progress_bar = tqdm(total=epochs)
    loss_fn = torch.nn.MSELoss()

    for epoch in range(epochs):
        for step, batch in enumerate(train_dataloader):
            pixel_batch = batch['image']
            # The final batch of an epoch may be smaller than `batch_size`.
            current_batch_size = pixel_batch.size(0)
            batch_embeddings = create_embeddings(batch, model).to(device)

            clean_images = torch.reshape(
                pixel_batch,
                (current_batch_size, SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE),
            )
            # The dataset transform yields PILToTensor output (integer pixels
            # in 0..255). Convert to float32 in [-1, 1] before mixing with
            # unit-variance noise and before feeding the VAE; with integer
            # inputs the noise-schedule arithmetic is meaningless.
            clean_images = clean_images.to(device=device, dtype=torch.float32) / 127.5 - 1.0

            noise = torch.randn(clean_images.shape, dtype=torch.float32, device=device)
            timesteps = torch.randint(
                0, scheduler.config.num_train_timesteps, (current_batch_size,)
            ).to(device=device)

            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
            latent_images = vae.encode(noisy_images).latent_dist.sample()

            optimizer.zero_grad()
            unet_results = unet(latent_images, timesteps, batch_embeddings).sample

            # The UNet output is decoded back to pixel space so it can be
            # compared with the pixel-space noise.
            # NOTE(review): this differs from the usual latent-diffusion
            # recipe (noise added to scaled latents, loss taken in latent
            # space) -- confirm it matches what RCTDiffusionPipeline does at
            # inference time.
            noise_pred = vae.decode(unet_results).sample

            loss = loss_fn(noise_pred, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            progress_bar.set_description(f'epoch={epoch}, batch_index={step}, last_loss={loss.item()}')

        if (epoch + 1) % save_model_interval == 0:
            # NOTE(review): Module.to() converts in place, so the float16
            # export below rounds the live training weights (and the frozen
            # VAE / text encoder) to half precision before they are cast back
            # to float32. Exporting from copies would avoid that precision
            # loss at the cost of extra memory.
            model = RCTDiffusionPipeline(
                unet.to(dtype=torch.float16),
                scheduler,
                vae.to(dtype=torch.float16),
                tokenizer,
                text_encoder.to(dtype=torch.float16),
            )
            save_and_test(model, epoch)

            # Back to float32 for the next training epochs.
            unet.to(dtype=torch.float32)
            vae.to(dtype=torch.float32)
            text_encoder.to(dtype=torch.float32)

        progress_bar.update(1)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: one image per batch, 100 epochs, export every 10.
    train_model(batch_size=1, epochs=100, save_model_interval=10)