# rct_model / train_model.py
# Author: frutiemax
# Commit 9bde8da: "Use vae for encoding and decoding for training" (8.16 kB)
from datasets import load_dataset
from PIL.Image import Image
import PIL
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torchvision
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from accelerate import Accelerator
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn as nn
# Side length, in pixels, of the square RGB canvas every sprite is letterboxed onto.
SAMPLE_SIZE = 256
# Spatial size given to the UNet as its `sample_size` (latent resolution);
# presumably SAMPLE_SIZE divided by the VAE's downsampling factor — TODO confirm.
LATENT_SIZE = 32
# Number of channels of the pixel-space images (the canvas is created as 'RGB').
SAMPLE_NUM_CHANNELS = 3
# Number of input/output channels of the UNet (latent channels).
LATENT_NUM_CHANNELS = 4
from torchvision import transforms
def save_and_test(pipeline, epoch):
    """Render a fixed test prompt with *pipeline*, save the images, then checkpoint it.

    The pipeline is called with the description ``'aleppo pine tree'`` and the
    colour ``'dark green'``.  Each returned image is written to
    ``out{index}_{epoch}.png``, and the pipeline itself is saved with
    ``save_pretrained`` into the directory ``rct_foliage_{epoch}``.

    Args:
        pipeline: callable pipeline returning a sequence of objects with a
            ``save(path)`` method (presumably PIL images) and exposing
            ``save_pretrained(path)``.
        epoch: number embedded in every output file / directory name.
    """
    test_images = pipeline(['aleppo pine tree'], ['dark green'])
    for index, picture in enumerate(test_images):
        picture.save(f'out{index}_{epoch}.png')
    pipeline.save_pretrained(f'rct_foliage_{epoch}')
def transform_images(image):
    """Letterbox a PIL image onto a black SAMPLE_SIZE x SAMPLE_SIZE RGB canvas.

    The image is resized with nearest-neighbour sampling, preserving its aspect
    ratio, so that its longer side matches SAMPLE_SIZE. It is then pasted
    centred on a black RGB canvas. The canvas is converted with
    ``torchvision.transforms.PILToTensor``, which gives a channels-first uint8
    tensor (values 0-255, no normalisation).

    Args:
        image: PIL image of any size; ``paste`` converts it to the canvas' RGB mode.

    Returns:
        Tensor of shape (3, SAMPLE_SIZE, SAMPLE_SIZE).
    """
    scale_factor = min(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
    # Clamp to at least one pixel: a very thin sprite would otherwise round a
    # side down to 0, which PIL's resize rejects.
    new_size = (
        max(1, int(scale_factor * image.width)),
        max(1, int(scale_factor * image.height)),
    )
    image = image.resize(new_size, resample=Resampling.NEAREST)

    canvas = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
    # Both offsets are non-negative because the resized image fits the canvas.
    offset = ((SAMPLE_SIZE - image.width) // 2, (SAMPLE_SIZE - image.height) // 2)
    canvas.paste(image, box=offset)
    return T.PILToTensor()(canvas)
def convert_images(dataset):
    """Batch transform installed with ``dataset.set_transform`` for training.

    Every entry of ``dataset["image"]`` is passed through
    :func:`transform_images`. The label columns ``object_description``,
    ``color1``, ``color2`` and ``color3`` are copied into new lists unchanged.

    Args:
        dataset: batch mapping from column name to a list of values.

    Returns:
        Dict with keys ``image``, ``object_description``, ``color1``,
        ``color2`` and ``color3``, in that order.
    """
    converted = {"image": [transform_images(picture) for picture in dataset["image"]]}
    for column in ('object_description', 'color1', 'color2', 'color3'):
        converted[column] = list(dataset[column])
    return converted
def convert_labels(dataset, model, num_images):
    """Pack the labels of the rows of *dataset* whose ``view`` is 0 into one tensor.

    Each row's object description and its three colours are wrapped as a
    single ``(label, 1.0)`` pair. Each pair is converted with the matching
    ``model.get_*_weights`` helper. The four resulting lists go to
    ``model.pack_labels_to_tensor(num_images, ...)``.

    Args:
        dataset: iterable of row dicts with ``view``, ``object_description``,
            ``color1``, ``color2`` and ``color3`` keys.
        model: object providing the ``get_*_weights`` helpers and
            ``pack_labels_to_tensor`` (presumably an RCTDiffusionPipeline).
        num_images: forwarded unchanged to ``pack_labels_to_tensor``.

    Returns:
        The packed labels, cast to float16 and moved to the CUDA device.
    """
    label_columns = ('object_description', 'color1', 'color2', 'color3')
    weight_helpers = (
        'get_object_description_weights',
        'get_color1_weights',
        'get_color2_weights',
        'get_color3_weights',
    )

    # Only rows with view == 0 contribute labels.
    front_views = [row for row in dataset if row['view'] == 0]
    raw_labels = {column: [row[column] for row in front_views] for column in label_columns}

    weight_lists = []
    for column, helper_name in zip(label_columns, weight_helpers):
        to_weights = getattr(model, helper_name)
        # every label gets a single (label, weight=1.0) entry
        weight_lists.append([to_weights([(label, 1.0)]) for label in raw_labels[column]])

    packed = model.pack_labels_to_tensor(num_images, *weight_lists)
    return packed.to(dtype=torch.float16, device='cuda')
def create_embeddings(dataset, model):
object_descriptions = dataset['object_description']
colors1 = dataset['color1']
colors2 = dataset['color2']
colors3 = dataset['color3']
return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
def _load_training_dataset(total_images):
    """Load ``frutiemax/rct_dataset`` with :func:`convert_images` as its on-the-fly transform.

    Args:
        total_images: number of leading rows of the ``train`` split to keep;
            ``None`` or any negative value keeps the whole split.
    """
    # A slice such as ``train[0:-1]`` would silently drop the last row, so
    # "all images" is requested with the plain split name instead.
    if total_images is None or total_images < 0:
        split = 'train'
    else:
        split = f'train[0:{total_images}]'
    dataset = load_dataset('frutiemax/rct_dataset', split=split)
    dataset.set_transform(convert_images)
    return dataset


def _build_unet():
    """Create the untrained, fp32 conditional UNet that is optimised by training."""
    unet = UNet2DConditionModel(
        sample_size=LATENT_SIZE,
        in_channels=LATENT_NUM_CHANNELS,
        out_channels=LATENT_NUM_CHANNELS,
        down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=768,
        block_out_channels=(128, 256, 512, 512),
        norm_num_groups=32,
    )
    return unet.to(dtype=torch.float32)


def _save_fp16_snapshot(unet, scheduler, vae, tokenizer, text_encoder, epoch):
    """Run :func:`save_and_test` on an fp16 version of the pipeline, then restore exact fp32 weights.

    Casting the live modules fp32 -> fp16 -> fp32 would permanently round the
    training weights to half precision. An exact fp32 copy of every module's
    state is therefore stashed on the CPU first and loaded back afterwards.
    ``load_state_dict`` copies into the existing parameters, so the
    optimizer's references and state stay valid.
    """
    modules = (unet, vae, text_encoder)
    fp32_states = [
        {key: value.detach().to('cpu', copy=True) for key, value in module.state_dict().items()}
        for module in modules
    ]
    for module in modules:
        module.to(dtype=torch.float16)
    try:
        # inference only: no autograd graph needed
        with torch.no_grad():
            save_and_test(RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder), epoch)
    finally:
        for module, state in zip(modules, fp32_states):
            module.to(dtype=torch.float32)
            module.load_state_dict(state)


def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100, save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
    """Train the conditional UNet of an RCTDiffusionPipeline on the RCT sprite dataset.

    Only the UNet is optimised (AdamW with a cosine schedule and linear
    warm-up). The CLIP text encoder and the VAE are loaded pretrained and
    kept frozen. Every ``save_model_interval`` epochs, an fp16 version of the
    pipeline is rendered and saved through :func:`save_and_test`.

    Args:
        batch_size: sprites per optimisation step.
        total_images: number of leading dataset rows to train on; any
            negative value (the default) uses the whole ``train`` split.
        epochs: passes over the dataset.
        scheduler_num_timesteps: ``num_train_timesteps`` of the DDPM scheduler.
        save_model_interval: checkpoint/test period, in epochs.
        start_learning_rate: peak learning rate.
        lr_warmup_steps: warm-up steps of the cosine learning-rate schedule.
    """
    dataset = _load_training_dataset(total_images)

    # Move to the GPU before building the optimizer, as torch.optim recommends.
    unet = _build_unet().to('cuda')
    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)

    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
    ).to('cuda')
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
    vae = vae.to(dtype=torch.float32, device='cuda')
    # Neither module is in the optimizer; freezing them stops gradients from
    # accumulating on their weights (they were never cleared). Gradients
    # still flow *through* vae.decode back to the UNet output.
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)

    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        # lr_scheduler.step() runs once per batch, not once per image
        num_training_steps=len(train_dataloader) * epochs,
    )

    progress_bar = tqdm(total=epochs)
    loss_fn = torch.nn.MSELoss()
    for epoch in range(epochs):
        for step, batch in enumerate(train_dataloader):
            embeddings = create_embeddings(batch, model).to('cuda')

            images = batch['image']
            images = torch.reshape(images, (images.size(0), SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE))
            # transform_images yields uint8 in [0, 255]. The VAE and the DDPM
            # noise schedule expect floats in [-1, 1]. With uint8 input,
            # add_noise cast alphas_cumprod to uint8 (all zeros), so the
            # "noisy" images contained no image signal at all.
            clean_images = images.to(device='cuda', dtype=torch.float32) / 127.5 - 1.0
            current_batch_size = clean_images.size(0)

            noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
            timesteps = torch.randint(
                0, scheduler.config.num_train_timesteps, (current_batch_size,), device='cuda'
            )
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

            # NOTE(review): the UNet sees latents of the *noisy image* and its
            # output is decoded back to pixel space to regress pixel noise.
            # Standard latent diffusion instead encodes the clean image, adds
            # noise in latent space (scaled by vae.config.scaling_factor) and
            # regresses latent noise — confirm this matches how
            # RCTDiffusionPipeline samples.
            with torch.no_grad():
                latent_images = vae.encode(noisy_images).latent_dist.sample()

            optimizer.zero_grad()
            unet_results = unet(latent_images, timesteps, embeddings).sample
            noise_pred = vae.decode(unet_results).sample
            loss = loss_fn(noise_pred, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            progress_bar.set_description(f'epoch={epoch}, batch_index={step}, last_loss={loss.item()}')

        if (epoch + 1) % save_model_interval == 0:
            # inference and checkpoint in float16; training weights remain exact fp32
            _save_fp16_snapshot(unet, scheduler, vae, tokenizer, text_encoder, epoch)
        progress_bar.update(1)
if __name__ == '__main__':
    # Script entry point: one sprite per step for 100 epochs, saving every 10 epochs.
    train_model(batch_size=1, epochs=100, save_model_interval=10)