Use SGD and text encoder/tokenizer
Browse files- rct_diffusion_pipeline.py +70 -7
- test_pipeline.py +21 -3
- train_model.py +82 -46
rct_diffusion_pipeline.py
CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
|
|
12 |
from tqdm.auto import tqdm
|
13 |
|
14 |
class RCTDiffusionPipeline(DiffusionPipeline):
|
15 |
-
def __init__(self, unet, scheduler, vae, latent_size=32, sample_size=256):
|
16 |
super().__init__()
|
17 |
|
18 |
# dictionnary that keeps the different classes of object description, color1, color2 and color3
|
@@ -26,11 +26,13 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
26 |
self.vae = vae
|
27 |
self.latent_size = latent_size
|
28 |
self.sample_size = sample_size
|
|
|
|
|
29 |
|
30 |
# channels for 1 image
|
31 |
self.num_channels = int(self.unet.config.in_channels / 4)
|
32 |
self.load_dictionaries_from_dataset()
|
33 |
-
self.register_modules(unet=unet, scheduler=scheduler, vae=vae)
|
34 |
|
35 |
def load_dictionaries_from_dataset(self):
|
36 |
dataset = load_dataset('frutiemax/rct_dataset')
|
@@ -177,13 +179,72 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
177 |
|
178 |
return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
|
179 |
|
180 |
-
def
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
|
183 |
|
184 |
-
|
185 |
-
if
|
186 |
return None
|
|
|
|
|
187 |
|
188 |
# set the inference steps
|
189 |
self.scheduler.set_timesteps(num_inference_steps)
|
@@ -196,8 +257,9 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
196 |
progress_bar.set_description(f'Inference step {epoch}')
|
197 |
|
198 |
for batch_index in range(batch_size):
|
|
|
199 |
with torch.no_grad():
|
200 |
-
noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=
|
201 |
previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches[batch_index]).prev_sample
|
202 |
noise_batches[batch_index] = previous_noisy_sample
|
203 |
progress_bar.update(1)
|
@@ -223,6 +285,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
223 |
image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
|
224 |
image = (image * 255).round().astype("uint8")
|
225 |
image = Image.fromarray(image)
|
|
|
226 |
output_images.append(image)
|
227 |
|
228 |
# for now just return the images
|
|
|
12 |
from tqdm.auto import tqdm
|
13 |
|
14 |
class RCTDiffusionPipeline(DiffusionPipeline):
|
15 |
+
def __init__(self, unet, scheduler, vae, text_tokenizer, text_encoder, latent_size=32, sample_size=256):
|
16 |
super().__init__()
|
17 |
|
18 |
# dictionnary that keeps the different classes of object description, color1, color2 and color3
|
|
|
26 |
self.vae = vae
|
27 |
self.latent_size = latent_size
|
28 |
self.sample_size = sample_size
|
29 |
+
self.text_encoder = text_encoder
|
30 |
+
self.text_tokenizer = text_tokenizer
|
31 |
|
32 |
# channels for 1 image
|
33 |
self.num_channels = int(self.unet.config.in_channels / 4)
|
34 |
self.load_dictionaries_from_dataset()
|
35 |
+
self.register_modules(unet=unet, scheduler=scheduler, vae=vae, text_tokenizer=text_tokenizer, text_encoder=text_encoder)
|
36 |
|
37 |
def load_dictionaries_from_dataset(self):
|
38 |
dataset = load_dataset('frutiemax/rct_dataset')
|
|
|
179 |
|
180 |
return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
|
181 |
|
182 |
+
def test_generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
|
183 |
+
batch_size = len(object_description)
|
184 |
+
|
185 |
+
embeddings = torch.Tensor(size=(batch_size, 77, 768))
|
186 |
+
for batch_index in range(batch_size):
|
187 |
+
prompt = f'{object_description[batch_index]},{color1[batch_index]},{color2[batch_index]}, {color3[batch_index]}'
|
188 |
+
tokens = self.text_tokenizer(prompt, \
|
189 |
+
padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
190 |
+
with torch.no_grad():
|
191 |
+
embeddings[batch_index] = self.text_encoder(tokens.input_ids.to('cuda'))[0]
|
192 |
+
|
193 |
+
return embeddings.to(dtype=torch.float16)
|
194 |
+
|
195 |
+
def generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
|
196 |
+
batch_size = len(object_description)
|
197 |
+
|
198 |
+
embeddings = torch.Tensor(size=(batch_size, 77, 768 * 4))
|
199 |
+
for batch_index in range(batch_size):
|
200 |
+
object_description_tokens = self.text_tokenizer(object_description[batch_index], \
|
201 |
+
padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
202 |
+
color1_tokens = self.text_tokenizer(color1[batch_index], \
|
203 |
+
padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
204 |
+
color2_tokens = self.text_tokenizer(color2[batch_index], \
|
205 |
+
padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
206 |
+
color3_tokens = self.text_tokenizer(color3[batch_index], \
|
207 |
+
padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
208 |
+
with torch.no_grad():
|
209 |
+
object_description_embeddings = self.text_encoder(object_description_tokens.input_ids.to('cuda'))[0]
|
210 |
+
color1_embeddings = self.text_encoder(color1_tokens.input_ids.to('cuda'))[0]
|
211 |
+
color2_embeddings = self.text_encoder(color2_tokens.input_ids.to('cuda'))[0]
|
212 |
+
color3_embeddings = self.text_encoder(color3_tokens.input_ids.to('cuda'))[0]
|
213 |
+
|
214 |
+
emb = torch.cat([object_description_embeddings, color1_embeddings, color2_embeddings, color3_embeddings], dim=2)
|
215 |
+
embeddings[batch_index] = emb
|
216 |
+
|
217 |
+
return embeddings.to(dtype=torch.float16)
|
218 |
+
|
219 |
+
def validate_inputs(self, object_description : list[str], color1 : list[str], \
|
220 |
+
color2 : list[str], color3 : list[str], batch_size) -> tuple[bool, list[str], list[str], list[str], list[str]]:
|
221 |
+
# check if the labels sizes are correct
|
222 |
+
if len(object_description) != batch_size:
|
223 |
+
return False
|
224 |
+
|
225 |
+
if len(color1) != batch_size:
|
226 |
+
return False
|
227 |
+
|
228 |
+
if color2 == None:
|
229 |
+
color2 = ['none'] * batch_size
|
230 |
+
elif len(color2) != batch_size:
|
231 |
+
return False
|
232 |
+
|
233 |
+
if color3 == None:
|
234 |
+
color3 = ['none'] * batch_size
|
235 |
+
elif len(color3) != batch_size:
|
236 |
+
return False
|
237 |
+
return True, object_description, color1, color2, color3
|
238 |
+
|
239 |
+
def __call__(self, object_description : list[str], color1 : list[str], \
|
240 |
+
color2 : list[str] = None, color3 : list[str] = None, \
|
241 |
batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
|
242 |
|
243 |
+
res, object_description, color1, color2, color3 = self.validate_inputs(object_description, color1, color2, color3, batch_size)
|
244 |
+
if res == False:
|
245 |
return None
|
246 |
+
embeddings = self.test_generate_embeddings(object_description, color1, color2, color3)
|
247 |
+
embeddings = embeddings.to('cuda')
|
248 |
|
249 |
# set the inference steps
|
250 |
self.scheduler.set_timesteps(num_inference_steps)
|
|
|
257 |
progress_bar.set_description(f'Inference step {epoch}')
|
258 |
|
259 |
for batch_index in range(batch_size):
|
260 |
+
noise_batches[batch_index] = self.scheduler.scale_model_input(noise_batches[batch_index], timestep=t)
|
261 |
with torch.no_grad():
|
262 |
+
noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=embeddings).sample
|
263 |
previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches[batch_index]).prev_sample
|
264 |
noise_batches[batch_index] = previous_noisy_sample
|
265 |
progress_bar.update(1)
|
|
|
285 |
image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
|
286 |
image = (image * 255).round().astype("uint8")
|
287 |
image = Image.fromarray(image)
|
288 |
+
image.save(f'test{image_index}.png')
|
289 |
output_images.append(image)
|
290 |
|
291 |
# for now just return the images
|
test_pipeline.py
CHANGED
@@ -1,20 +1,38 @@
|
|
1 |
from rct_diffusion_pipeline import RCTDiffusionPipeline
|
2 |
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
|
3 |
import torch
|
|
|
4 |
|
5 |
torch_device = "cuda"
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
unet = UNet2DConditionModel(sample_size=32, in_channels=16, out_channels=16, \
|
8 |
down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
|
9 |
-
up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=
|
10 |
block_out_channels=(64, 128, 256), norm_num_groups=32)
|
11 |
unet = unet.to('cuda', dtype=torch.float16)
|
12 |
scheduler = DDPMScheduler(num_train_timesteps=20)
|
13 |
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
|
14 |
vae = vae.to('cuda', dtype=torch.float16)
|
15 |
|
16 |
-
pipeline = RCTDiffusionPipeline(unet, scheduler, vae)
|
17 |
-
output = pipeline([
|
18 |
pipeline.save_pretrained('test')
|
19 |
|
20 |
# from PIL import Image
|
|
|
1 |
from rct_diffusion_pipeline import RCTDiffusionPipeline
|
2 |
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
|
3 |
import torch
|
4 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
5 |
|
6 |
torch_device = "cuda"
|
7 |
|
8 |
+
# test of text tokenizers
|
9 |
+
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
|
10 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
11 |
+
"CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
|
12 |
+
).to('cuda')
|
13 |
+
|
14 |
+
test1 = tokenizer(['aleppo pine tree, common oak tree'], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
15 |
+
#test3 = tokenizer([1.0, 0.0, .05], is_split_into_words=True, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
16 |
+
|
17 |
+
with torch.no_grad():
|
18 |
+
test1 = text_encoder(test1.input_ids.to('cuda'))[0]
|
19 |
+
|
20 |
+
test2 = tokenizer('dark green', padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
21 |
+
|
22 |
+
with torch.no_grad():
|
23 |
+
test2 = text_encoder(test2.input_ids.to('cuda'))[0]
|
24 |
+
|
25 |
unet = UNet2DConditionModel(sample_size=32, in_channels=16, out_channels=16, \
|
26 |
down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
|
27 |
+
up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=768*4,
|
28 |
block_out_channels=(64, 128, 256), norm_num_groups=32)
|
29 |
unet = unet.to('cuda', dtype=torch.float16)
|
30 |
scheduler = DDPMScheduler(num_train_timesteps=20)
|
31 |
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
|
32 |
vae = vae.to('cuda', dtype=torch.float16)
|
33 |
|
34 |
+
pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
|
35 |
+
output = pipeline(['aleppo pine tree'], ['dark green'])
|
36 |
pipeline.save_pretrained('test')
|
37 |
|
38 |
# from PIL import Image
|
train_model.py
CHANGED
@@ -11,6 +11,7 @@ from diffusers.optimization import get_cosine_schedule_with_warmup
|
|
11 |
from tqdm.auto import tqdm
|
12 |
from accelerate import Accelerator
|
13 |
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
|
|
|
14 |
|
15 |
SAMPLE_SIZE = 256
|
16 |
LATENT_SIZE = 32
|
@@ -18,12 +19,12 @@ SAMPLE_NUM_CHANNELS = 3
|
|
18 |
LATENT_NUM_CHANNELS = 4
|
19 |
|
20 |
def save_and_test(pipeline, epoch):
|
21 |
-
outputs = pipeline([
|
22 |
for image_index in range(len(outputs)):
|
23 |
file_name = f'out{image_index}_{epoch}.png'
|
24 |
outputs[image_index].save(file_name)
|
25 |
|
26 |
-
model_file = f'rct_foliage_{epoch}
|
27 |
pipeline.save_pretrained(model_file)
|
28 |
|
29 |
def convert_images(dataset):
|
@@ -42,18 +43,18 @@ def convert_images(dataset):
|
|
42 |
for entry in views[view_index]:
|
43 |
image = entry['image']
|
44 |
|
45 |
-
scale_factor =
|
46 |
-
image = Image.resize(image, size=(scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
|
47 |
|
48 |
-
new_image = PIL.Image.new('
|
49 |
-
new_image.paste(image, box=(int((
|
50 |
images.append(new_image)
|
51 |
image_views.append(images)
|
52 |
|
53 |
del views
|
54 |
|
55 |
# convert those views in tensors
|
56 |
-
targets = torch.Tensor(size=(num_images, 4,
|
57 |
pillow_to_tensor = T.ToTensor()
|
58 |
|
59 |
for image_index in range(num_images):
|
@@ -62,7 +63,7 @@ def convert_images(dataset):
|
|
62 |
del image_views
|
63 |
del entries
|
64 |
|
65 |
-
return torch.reshape(targets, (num_images, 4 *
|
66 |
|
67 |
def convert_labels(dataset, model, num_images):
|
68 |
# get the labels
|
@@ -96,80 +97,115 @@ def convert_labels(dataset, model, num_images):
|
|
96 |
del dataset
|
97 |
return class_labels.to(dtype=torch.float16, device='cuda')
|
98 |
|
99 |
-
def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=
|
100 |
dataset = load_dataset('frutiemax/rct_dataset')
|
101 |
dataset = dataset['train']
|
102 |
|
103 |
targets = convert_images(dataset)
|
104 |
-
num_images = int(dataset.num_rows / 4)
|
105 |
|
106 |
-
unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS
|
107 |
-
down_block_types=(
|
108 |
-
up_block_types=(
|
109 |
-
block_out_channels=(
|
110 |
unet = unet.to(dtype=torch.float16)
|
111 |
-
scheduler = DDPMScheduler(num_train_timesteps=
|
|
|
|
|
|
|
|
|
112 |
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
|
113 |
vae = vae.to(dtype=torch.float16)
|
114 |
|
115 |
-
optimizer = torch.optim.
|
116 |
lr_scheduler = get_cosine_schedule_with_warmup(
|
117 |
optimizer=optimizer,
|
118 |
num_warmup_steps=lr_warmup_steps,
|
119 |
num_training_steps=num_images * epochs
|
120 |
)
|
121 |
-
model = RCTDiffusionPipeline(unet, scheduler, vae)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
labels = convert_labels(dataset, model, num_images)
|
123 |
del model
|
124 |
|
|
|
|
|
|
|
|
|
|
|
125 |
# lets train for 100 epoch for each sprite in the dataset with a random noise level
|
126 |
progress_bar = tqdm(total=epochs)
|
127 |
accelerator = Accelerator(mixed_precision='fp16')
|
|
|
128 |
unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)
|
129 |
|
|
|
|
|
|
|
130 |
for epoch in range(epochs):
|
131 |
# create a noisy version of each sprite
|
132 |
for batch_index in range(0, num_images, batch_size):
|
133 |
-
progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
|
134 |
batch_end = np.minimum(num_images, batch_index + batch_size)
|
135 |
clean_images = targets[batch_index:batch_end]
|
136 |
-
clean_images = torch.reshape(clean_images, ((batch_end - batch_index),
|
|
|
137 |
|
138 |
noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
|
139 |
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
|
|
|
140 |
#timesteps = timesteps.to(dtype=torch.int, device='cuda')
|
141 |
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
|
142 |
-
|
143 |
-
|
144 |
-
#
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
if (epoch + 1) % save_model_interval == 0:
|
169 |
-
model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae)
|
170 |
save_and_test(model, epoch)
|
171 |
progress_bar.update(1)
|
172 |
|
173 |
|
174 |
if __name__ == '__main__':
|
175 |
-
train_model(1, save_model_interval=1)
|
|
|
11 |
from tqdm.auto import tqdm
|
12 |
from accelerate import Accelerator
|
13 |
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
|
14 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
15 |
|
16 |
SAMPLE_SIZE = 256
|
17 |
LATENT_SIZE = 32
|
|
|
19 |
LATENT_NUM_CHANNELS = 4
|
20 |
|
21 |
def save_and_test(pipeline, epoch):
|
22 |
+
outputs = pipeline(['aleppo pine tree'], ['dark green'])
|
23 |
for image_index in range(len(outputs)):
|
24 |
file_name = f'out{image_index}_{epoch}.png'
|
25 |
outputs[image_index].save(file_name)
|
26 |
|
27 |
+
model_file = f'rct_foliage_{epoch}'
|
28 |
pipeline.save_pretrained(model_file)
|
29 |
|
30 |
def convert_images(dataset):
|
|
|
43 |
for entry in views[view_index]:
|
44 |
image = entry['image']
|
45 |
|
46 |
+
scale_factor = np.minimum(LATENT_SIZE / image.width, LATENT_SIZE / image.height)
|
47 |
+
image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
|
48 |
|
49 |
+
new_image = PIL.Image.new('RGBA', (LATENT_SIZE, LATENT_SIZE))
|
50 |
+
new_image.paste(image, box=(int((LATENT_SIZE - image.width)/2), int((LATENT_SIZE - image.height)/2)))
|
51 |
images.append(new_image)
|
52 |
image_views.append(images)
|
53 |
|
54 |
del views
|
55 |
|
56 |
# convert those views in tensors
|
57 |
+
targets = torch.Tensor(size=(num_images, 4, LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE)).to(dtype=torch.float16)
|
58 |
pillow_to_tensor = T.ToTensor()
|
59 |
|
60 |
for image_index in range(num_images):
|
|
|
63 |
del image_views
|
64 |
del entries
|
65 |
|
66 |
+
return torch.reshape(targets, (num_images, 4 * LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE))
|
67 |
|
68 |
def convert_labels(dataset, model, num_images):
|
69 |
# get the labels
|
|
|
97 |
del dataset
|
98 |
return class_labels.to(dtype=torch.float16, device='cuda')
|
99 |
|
100 |
+
def train_model(batch_size=4, total_images=None, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=1):
|
101 |
dataset = load_dataset('frutiemax/rct_dataset')
|
102 |
dataset = dataset['train']
|
103 |
|
104 |
targets = convert_images(dataset)
|
105 |
+
num_images = int(dataset.num_rows / 4) if total_images == None else int(total_images / 4)
|
106 |
|
107 |
+
unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS*4, out_channels=LATENT_NUM_CHANNELS*4, \
|
108 |
+
down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
|
109 |
+
up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
|
110 |
+
block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
|
111 |
unet = unet.to(dtype=torch.float16)
|
112 |
+
scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
|
113 |
+
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
|
114 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
115 |
+
"CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
|
116 |
+
).to('cuda')
|
117 |
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
|
118 |
vae = vae.to(dtype=torch.float16)
|
119 |
|
120 |
+
optimizer = torch.optim.SGD(unet.parameters(), lr=start_learning_rate)
|
121 |
lr_scheduler = get_cosine_schedule_with_warmup(
|
122 |
optimizer=optimizer,
|
123 |
num_warmup_steps=lr_warmup_steps,
|
124 |
num_training_steps=num_images * epochs
|
125 |
)
|
126 |
+
model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
|
127 |
+
|
128 |
+
# get all the object descriptions, color1, color2, color3
|
129 |
+
object_descriptions = dataset['object_description']
|
130 |
+
colors1 = dataset['color1']
|
131 |
+
colors2 = dataset['color2']
|
132 |
+
colors3 = dataset['color3']
|
133 |
+
|
134 |
+
# we only need 1 of the 4 views
|
135 |
+
object_descriptions = [object_descriptions[desc_index] for desc_index in range(0, len(object_descriptions), 4)]
|
136 |
+
colors1 = [colors1[desc_index] for desc_index in range(0, len(colors1), 4)]
|
137 |
+
colors2 = [colors2[desc_index] for desc_index in range(0, len(colors2), 4)]
|
138 |
+
colors3 = [colors3[desc_index] for desc_index in range(0, len(colors3), 4)]
|
139 |
+
#embeddings = model.generate_embeddings(object_descriptions, colors1, colors2, colors3)
|
140 |
+
embeddings = model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
|
141 |
+
|
142 |
labels = convert_labels(dataset, model, num_images)
|
143 |
del model
|
144 |
|
145 |
+
if total_images != None:
|
146 |
+
targets = targets[:int(total_images/4)]
|
147 |
+
label_indices = [index for index in range(0, total_images, 4)]
|
148 |
+
labels = labels[label_indices]
|
149 |
+
|
150 |
# lets train for 100 epoch for each sprite in the dataset with a random noise level
|
151 |
progress_bar = tqdm(total=epochs)
|
152 |
accelerator = Accelerator(mixed_precision='fp16')
|
153 |
+
accelerator.clip_grad_norm_(unet.parameters(), 1.0)
|
154 |
unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)
|
155 |
|
156 |
+
loss_fn = torch.nn.MSELoss()
|
157 |
+
|
158 |
+
tensor_to_pillow = T.ToPILImage()
|
159 |
for epoch in range(epochs):
|
160 |
# create a noisy version of each sprite
|
161 |
for batch_index in range(0, num_images, batch_size):
|
|
|
162 |
batch_end = np.minimum(num_images, batch_index + batch_size)
|
163 |
clean_images = targets[batch_index:batch_end]
|
164 |
+
clean_images = torch.reshape(clean_images, ((batch_end - batch_index), LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).\
|
165 |
+
to(device='cuda', dtype=torch.float16)
|
166 |
|
167 |
noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
|
168 |
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
|
169 |
+
|
170 |
#timesteps = timesteps.to(dtype=torch.int, device='cuda')
|
171 |
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
|
172 |
+
|
173 |
+
# with accelerator.accumulate(unet):
|
174 |
+
# assert not torch.any(torch.isnan(timesteps))
|
175 |
+
|
176 |
+
# batch_embeddings = embeddings[batch_index:batch_end]
|
177 |
+
# batch_embeddings = batch_embeddings.to('cuda')
|
178 |
+
|
179 |
+
# optimizer.zero_grad()
|
180 |
+
# unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
|
181 |
+
# unet_results = unet_results.to(dtype=torch.float16)
|
182 |
+
|
183 |
+
# loss = loss_fn(unet_results, noise)
|
184 |
+
# accelerator.backward(loss)
|
185 |
+
|
186 |
+
# optimizer.step()
|
187 |
+
# lr_scheduler.step()
|
188 |
+
# optimizer.zero_grad()
|
189 |
+
|
190 |
+
batch_embeddings = embeddings[batch_index:batch_end]
|
191 |
+
batch_embeddings = batch_embeddings.to('cuda')
|
192 |
+
|
193 |
+
optimizer.zero_grad()
|
194 |
+
unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
|
195 |
+
unet_results = unet_results.to(dtype=torch.float16)
|
196 |
+
loss = loss_fn(unet_results, noise)
|
197 |
+
loss.backward()
|
198 |
+
optimizer.step()
|
199 |
+
lr_scheduler.step()
|
200 |
+
optimizer.zero_grad()
|
201 |
+
|
202 |
+
progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}, last_loss={loss.item()}')
|
203 |
|
204 |
if (epoch + 1) % save_model_interval == 0:
|
205 |
+
model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae, tokenizer, text_encoder)
|
206 |
save_and_test(model, epoch)
|
207 |
progress_bar.update(1)
|
208 |
|
209 |
|
210 |
if __name__ == '__main__':
|
211 |
+
train_model(1, total_images=4, save_model_interval=1)
|