# rct_model/train_model.py
from datasets import load_dataset
from PIL import Image
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from accelerate import Accelerator
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
SAMPLE_SIZE = 256
LATENT_SIZE = 32
SAMPLE_NUM_CHANNELS = 3
LATENT_NUM_CHANNELS = 4
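# These sizes follow the Stable Diffusion VAE geometry: the encoder downsamples
# by a factor of 8 (256 -> 32) and yields 4 latent channels per RGB image, so the
# 4 stacked views amount to 4 * 3 = 12 image channels and 4 * 4 = 16 latent channels.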

def save_and_test(pipeline, epoch):
    # render a fixed test prompt so successive checkpoints can be compared visually
    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)
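    # NOTE: diffusers' DiffusionPipeline.save_pretrained treats its argument as
    # a directory, so the '.pth' name below becomes a folder unless
    # RCTDiffusionPipeline overrides that behaviour.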
    model_file = f'rct_foliage_{epoch}.pth'
    pipeline.save_pretrained(model_file)

def convert_images(dataset):
    # split the dataset entries into four lists, one per view
    views = []
    num_images = int(dataset.num_rows / 4)
    for view_index in range(4):
        entries = [entry for entry in dataset if entry['view'] == view_index]
        views.append(entries)

    # convert those images to 256x256 by integer upscaling and centring them on a black canvas
    image_views = []
    for view_index in range(4):
        images = []
        for entry in views[view_index]:
            image = entry['image']
            # integer upscale factor; assumes the source sprites are no larger than SAMPLE_SIZE
            scale_factor = int(np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height))
            image = image.resize((scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
            new_image = Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
            new_image.paste(image, box=(int((SAMPLE_SIZE - image.width) / 2), int((SAMPLE_SIZE - image.height) / 2)))
            images.append(new_image)
        image_views.append(images)
    del views

    # convert those views to a single tensor
    targets = torch.empty(size=(num_images, 4, SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE), dtype=torch.float16)
    pillow_to_tensor = T.ToTensor()
    for image_index in range(num_images):
        for view_index in range(4):
            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
    del image_views
    del entries

    # stack the 4 views along the channel axis: (N, 4, 3, H, W) -> (N, 12, H, W)
    return torch.reshape(targets, (num_images, 4 * SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE))
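
# Hypothetical shape check, assuming the dataset holds N sprites * 4 views:
#   targets = convert_images(dataset)   # -> torch.Size([N, 12, 256, 256])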

def convert_labels(dataset, model, num_images):
    # get the labels from the view 0 entries, since all four views share them
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]
    del view0_entries

    # wrap each description, color1, color2 and color3 in a list of (label, weight=1.0) tuples
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

    # convert those tuples to numpy arrays using the helper functions of the model
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    # finally, pack those numpy arrays into a single tensor
    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset
    return class_labels.to(dtype=torch.float16, device='cuda')
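
# The packed label tensor is computed once for the whole dataset and kept on the
# GPU; it conditions the UNet through cross-attention and never changes during
# training.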

def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
    dataset = load_dataset('frutiemax/rct_dataset')
    dataset = dataset['train']

    targets = convert_images(dataset)
    num_images = int(dataset.num_rows / 4)

    # the UNet denoises the stacked latents of the 4 views, hence LATENT_NUM_CHANNELS * 4 in/out channels
    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS * 4, out_channels=LATENT_NUM_CHANNELS * 4,
                                down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),
                                up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
                                cross_attention_dim=160, block_out_channels=(64, 128, 256), norm_num_groups=32)
    unet = unet.to(dtype=torch.float16)

    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
    vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', use_safetensors=True)
    vae = vae.to(dtype=torch.float16)
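    # NOTE: Stable Diffusion pipelines usually multiply VAE latents by
    # vae.config.scaling_factor (~0.18215); this script works with raw, unscaled
    # latents, so sampling must do the same for consistency.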
    optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_images * epochs
    )

    # build a throwaway pipeline only to construct the label dictionaries
    model = RCTDiffusionPipeline(unet, scheduler, vae)
    model.load_dictionaries_from_dataset()
    labels = convert_labels(dataset, model, num_images)
    del model

    # train for `epochs` epochs on each sprite in the dataset with a random noise level
    progress_bar = tqdm(total=epochs)
    accelerator = Accelerator(mixed_precision='fp16')
    unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)
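    # NOTE: accelerator.prepare() is normally also given the optimizer (and a
    # dataloader) so that fp16 gradient scaling is handled by Accelerate; here
    # only the models and the lr scheduler are wrapped.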
    for epoch in range(epochs):
        # create a noisy version of each sprite, one batch at a time
        for batch_index in range(0, num_images, batch_size):
            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
            batch_end = min(num_images, batch_index + batch_size)
            clean_images = targets[batch_index:batch_end]
            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), SAMPLE_NUM_CHANNELS * 4, SAMPLE_SIZE, SAMPLE_SIZE)).to(device='cuda', dtype=torch.float16)

            # draw a random timestep per sample and add the corresponding noise level
            noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
            del clean_images
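            # NOTE: unlike standard latent diffusion, noise is added in *pixel*
            # space here, and the noisy images and the noise are then encoded
            # separately by the VAE below; the UNet's regression target is thus
            # the VAE encoding of the noise rather than latent-space noise.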
            # encode through the vae
            with accelerator.accumulate(unet):
                latent_images = torch.empty(size=(batch_end - batch_index, LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE), device='cuda', dtype=torch.float16)
                latent_noises = torch.empty(size=(batch_end - batch_index, LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE), device='cuda', dtype=torch.float16)

                # encode each of the 4 views independently and re-stack the latents channel-wise
                for view_index in range(4):
                    images = noisy_images[:, view_index*SAMPLE_NUM_CHANNELS:(view_index+1)*SAMPLE_NUM_CHANNELS]
                    result = vae.encode(images).latent_dist.sample()
                    latent_images[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result

                    images = noise[:, view_index*SAMPLE_NUM_CHANNELS:(view_index+1)*SAMPLE_NUM_CHANNELS]
                    result = vae.encode(images).latent_dist.sample()
                    latent_noises[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result
                del noise
                del noisy_images
                # predict the encoded noise and take an optimizer step
                unet_results = unet(latent_images, timesteps, labels[batch_index:batch_end])[0]
                unet_results = unet_results.to(dtype=torch.float16)

                loss = F.mse_loss(unet_results, latent_noises)
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
        if (epoch + 1) % save_model_interval == 0:
            model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae)
            save_and_test(model, epoch)
        progress_bar.update(1)

if __name__ == '__main__':
    train_model(1, save_model_interval=1)
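    # batch_size=1 with a checkpoint after every epoch looks like a smoke-test
    # setup; the defaults (batch_size=4, save_model_interval=10) are presumably
    # the full training configuration.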