from datasets import load_dataset
from PIL import Image
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from accelerate import Accelerator

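# Run a fixed smoke-test prompt through the pipeline and checkpoint it. The
# [[('label', weight)]] argument format is RCTDiffusionPipeline's weighted-label
# interface, the same one used for the training labels below.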
def save_and_test(pipeline, epoch):
    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)

    model_file = f'rct_foliage_{epoch}.pth'
    pipeline.save_pretrained(model_file)

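# Train the diffusion model on frutiemax/rct_dataset. Each object comes with 4
# isometric views; the views are stacked depth-wise into a single 12-channel
# target so the UNet learns to denoise all of them jointly.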
def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
    dataset = load_dataset('frutiemax/rct_dataset')
    dataset = dataset['train']

    # the dataset holds 4 rows (one per view) for each object
    num_images = int(dataset.num_rows / 4)

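    # group the dataset rows by view index (0..3)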
    views = []
    for view_index in range(4):
        entries = [entry for entry in dataset if entry['view'] == view_index]
        views.append(entries)

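    # upscale each sprite with nearest-neighbour filtering (keeps pixel art
    # crisp) and centre it on a black 256x256 canvas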
    image_views = []
    for view_index in range(4):
        images = []
        for entry in views[view_index]:
            image = entry['image']

            # integer scale factor; assumes the source sprites fit inside
            # 256x256, otherwise this rounds down to zero
            scale_factor = int(np.minimum(256 / image.width, 256 / image.height))
            image = image.resize((scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)

            new_image = Image.new('RGB', (256, 256))
            new_image.paste(image, box=(int((256 - image.width)/2), int((256 - image.height)/2)))
            images.append(new_image)
        image_views.append(images)

    del views

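    # pack the PIL images into one half-precision tensor of shape
    # (num_images, 4 views, 3 channels, 256, 256)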
    targets = torch.empty(size=(num_images, 4, 3, 256, 256), dtype=torch.float16)
    pillow_to_tensor = T.ToTensor()

    for image_index in range(num_images):
        for view_index in range(4):
            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
    del image_views
    del entries

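    # merge the view axis into the channel axis: 4 views x 3 channels = 12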
    targets = torch.reshape(targets, (num_images, 12, 256, 256))

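    # the label columns are identical across the 4 views, so read them from
    # view 0 only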
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]

    del view0_entries

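    # wrap every label in the [('label', weight)] format the pipeline expects,
    # with full weight on the single ground-truth label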
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

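    # let the pipeline turn the weighted labels into class-embedding weights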
    model = RCTDiffusionPipeline()
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset

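    # Adam on the UNet parameters with a cosine learning-rate schedule after a
    # linear warmup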
    optimizer = torch.optim.Adam(model.unet.parameters(), lr=start_learning_rate)

    # one optimizer step happens per batch, not per image
    num_training_steps = int(np.ceil(num_images / batch_size)) * epochs
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_training_steps,
    )

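    # fp16 mixed-precision training via accelerate, with tensorboard logs
    # written under ./logs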
    progress_bar = tqdm(total=epochs)
    accelerator = Accelerator(
        mixed_precision='fp16',
        gradient_accumulation_steps=1,
        log_with="tensorboard",
        project_dir='logs',
    )

    unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(
        model.unet, model.scheduler, optimizer, lr_scheduler)

    # note: `model` must stay alive here, it is reused below for checkpointing
    scheduler.set_timesteps(scheduler_num_timesteps)

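    # standard DDPM training loop: noise the targets at random timesteps and
    # train the UNet to predict that noise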
    for epoch in range(epochs):
        for batch_index in range(0, num_images, batch_size):
            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
            batch_end = min(num_images, batch_index + batch_size)
            clean_images = targets[batch_index:batch_end]

            noise = torch.randn(clean_images.shape, dtype=torch.float16)
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)

            with accelerator.accumulate(unet):
                noise_pred = unet(noisy_images, timesteps.to(device='cuda', dtype=torch.float16), class_labels[batch_index:batch_end].to(device='cuda', dtype=torch.float16), return_dict=False)[0]

                loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

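        # every save_model_interval epochs, unwrap the UNet and write a
        # checkpoint plus sample images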
        if (epoch + 1) % save_model_interval == 0:
            model.unet = accelerator.unwrap_model(unet)
            model.scheduler = scheduler
            save_and_test(model, epoch)
        progress_bar.update(1)

if __name__ == '__main__':
    train_model(4)