from datasets import load_dataset
from PIL import Image
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn as nn

# Sprites are letterboxed to 256x256 RGB; the Stable Diffusion VAE downsamples
# by a factor of 8, giving 32x32 latents with 4 channels.
SAMPLE_SIZE = 256
LATENT_SIZE = 32
SAMPLE_NUM_CHANNELS = 3
LATENT_NUM_CHANNELS = 4
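# Training script for the RCT foliage latent-diffusion pipeline: sprites from
# frutiemax/rct_dataset are letterboxed to 256x256, noised with a DDPM
# schedule, pushed through the SD VAE and a text-conditioned UNet, and the
# decoded output is regressed against the injected noise.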
def save_and_test(pipeline, epoch):
    outputs = pipeline(['aleppo pine tree'], ['dark green'])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)

    model_file = f'rct_foliage_{epoch}'
    pipeline.save_pretrained(model_file)
def transform_images(image):
    # scale the sprite so its longest side fits in SAMPLE_SIZE, preserving the
    # aspect ratio and the hard pixel edges (nearest-neighbour resampling)
    scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
    image = image.resize((int(scale_factor * image.width), int(scale_factor * image.height)),
                         resample=Resampling.NEAREST)

    # paste it centered on a black SAMPLE_SIZE x SAMPLE_SIZE canvas
    new_image = Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
    new_image.paste(image, box=(int((SAMPLE_SIZE - image.width) / 2), int((SAMPLE_SIZE - image.height) / 2)))

    # the SD VAE expects float inputs in [-1, 1], not uint8 in [0, 255]
    pil_to_tensor = T.PILToTensor()
    return pil_to_tensor(new_image).to(torch.float32) / 127.5 - 1.0
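# Example (hypothetical sprite size): a 64x31 sprite gets scale_factor 4, is
# resized to 256x124, then centered on the 256x256 canvas:
#   sprite = Image.new('RGB', (64, 31))
#   tensor = transform_images(sprite)   # float tensor of shape (3, 256, 256)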
def convert_images(dataset):
    images = [transform_images(image) for image in dataset['image']]
    object_descriptions = [obj_desc for obj_desc in dataset['object_description']]
    colors1 = [color1 for color1 in dataset['color1']]
    colors2 = [color2 for color2 in dataset['color2']]
    colors3 = [color3 for color3 in dataset['color3']]
    return {'image': images, 'object_description': object_descriptions,
            'color1': colors1, 'color2': colors2, 'color3': colors3}
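# Note: dataset.set_transform(convert_images) below applies this conversion
# lazily, per batch, so the letterboxed tensors are never materialized for the
# whole dataset at once.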
def convert_labels(dataset, model, num_images):
    # keep only the entries for view 0 and collect their labels
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]
    del view0_entries

    # convert the descriptions and colors to lists of (label, weight=1.0) tuples
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

    # convert those tuples to numpy arrays using the model's helper functions
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    # finally, pack the numpy arrays into a single tensor
    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset
    return class_labels.to(dtype=torch.float16, device='cuda')
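# Note: convert_labels is not referenced by train_model below; it appears to be
# a leftover from an earlier class-label variant of the pipeline.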
def create_embeddings(dataset, model):
    object_descriptions = dataset['object_description']
    colors1 = dataset['color1']
    colors2 = dataset['color2']
    colors3 = dataset['color3']
    return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
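# test_generate_embeddings is a helper on RCTDiffusionPipeline; presumably it
# encodes the text labels with the CLIP text encoder into the cross-attention
# embeddings the UNet consumes below.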
def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100,
                save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
    # total_images <= 0 selects the full dataset; a literal 'train[0:-1]' slice
    # would silently drop the last row
    split = 'train' if total_images <= 0 else f'train[0:{total_images}]'
    dataset = load_dataset('frutiemax/rct_dataset', split=split)
    dataset.set_transform(convert_images)
    num_images = dataset.num_rows

    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS,
                                out_channels=LATENT_NUM_CHANNELS,
                                down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D',
                                                  'CrossAttnDownBlock2D', 'DownBlock2D'),
                                up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D',
                                                'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
                                cross_attention_dim=768, block_out_channels=(128, 256, 512, 512),
                                norm_num_groups=32)
    unet = unet.to(dtype=torch.float32)

    # https://discuss.pytorch.org/t/training-with-half-precision/11815
    for layer in unet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()
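    # Note: UNet2DConditionModel normalizes with GroupNorm rather than
    # BatchNorm2d, so this loop is most likely a no-op; it is kept for parity
    # with the discussion linked above.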
    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
    tokenizer = CLIPTokenizer.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(
        'CompVis/stable-diffusion-v1-4', subfolder='text_encoder', use_safetensors=True
    ).to('cuda')
    vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', use_safetensors=True)
    vae = vae.to(dtype=torch.float32, device='cuda')

    optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)

    # the LR scheduler is stepped once per batch, so size the cosine schedule
    # in optimizer steps rather than in images
    num_steps_per_epoch = (num_images + batch_size - 1) // batch_size
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_steps_per_epoch * epochs
    )
    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
    unet = unet.to('cuda')
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # train every sprite in the dataset at a random noise level, once per epoch
    progress_bar = tqdm(total=epochs)
    loss_fn = torch.nn.MSELoss()
    for epoch in range(epochs):
        for step, batch in enumerate(train_dataloader):
            clean_images = batch['image']
            # the last batch can be smaller than batch_size
            batch_size = clean_images.size(0)
            embeddings = create_embeddings(batch, model)
            clean_images = clean_images.reshape(batch_size, SAMPLE_NUM_CHANNELS,
                                                SAMPLE_SIZE, SAMPLE_SIZE).to(device='cuda')

            # create a noisy version of each sprite at a random timestep
            noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device='cuda')
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
            batch_embeddings = embeddings.to('cuda')

            # use the VAE to project the noisy sprites into latent space
            latent_images = vae.encode(noisy_images).latent_dist.sample()
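            # Note: Stable Diffusion pipelines usually multiply these latents
            # by vae.config.scaling_factor (0.18215) before the UNet; this
            # script feeds the raw samples instead.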
            optimizer.zero_grad()
            unet_results = unet(latent_images, timesteps, encoder_hidden_states=batch_embeddings).sample

            # decode the predicted latents back to pixel space
            noise_pred = vae.decode(unet_results).sample
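            # The loss is taken in pixel space: gradients flow through the VAE
            # decoder (whose weights are not in the optimizer) so the UNet
            # learns to reproduce the injected noise after decoding. This
            # differs from the usual latent-diffusion objective of predicting
            # noise in latent space.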
            loss = loss_fn(noise_pred, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            progress_bar.set_description(f'epoch={epoch}, batch_index={step}, last_loss={loss.item()}')
        if (epoch + 1) % save_model_interval == 0:
            # run inference in float16 for speed, then restore float32 for training
            model = RCTDiffusionPipeline(unet.to(dtype=torch.float16), scheduler,
                                         vae.to(dtype=torch.float16), tokenizer,
                                         text_encoder.to(dtype=torch.float16))
            save_and_test(model, epoch)
            unet.to(dtype=torch.float32)
            vae.to(dtype=torch.float32)
            text_encoder.to(dtype=torch.float32)
        progress_bar.update(1)
if __name__ == '__main__':
    train_model(1, save_model_interval=10, epochs=100)